vellum-ai 0.14.39__py3-none-any.whl → 0.14.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/workflows/nodes/bases/tests/test_base_node.py +24 -0
  3. vellum/workflows/nodes/core/try_node/node.py +1 -2
  4. vellum/workflows/nodes/experimental/tool_calling_node/__init__.py +3 -0
  5. vellum/workflows/nodes/experimental/tool_calling_node/node.py +147 -0
  6. vellum/workflows/nodes/experimental/tool_calling_node/utils.py +132 -0
  7. vellum/workflows/nodes/utils.py +4 -2
  8. vellum/workflows/outputs/base.py +3 -2
  9. vellum/workflows/references/output.py +20 -0
  10. vellum/workflows/state/base.py +36 -14
  11. vellum/workflows/state/tests/test_state.py +5 -2
  12. vellum/workflows/types/stack.py +11 -0
  13. vellum/workflows/workflows/base.py +5 -0
  14. vellum/workflows/workflows/tests/test_base_workflow.py +96 -9
  15. {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.40.dist-info}/METADATA +1 -1
  16. {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.40.dist-info}/RECORD +67 -62
  17. vellum_cli/push.py +0 -2
  18. vellum_ee/workflows/display/base.py +14 -1
  19. vellum_ee/workflows/display/nodes/base_node_display.py +56 -14
  20. vellum_ee/workflows/display/nodes/get_node_display_class.py +9 -15
  21. vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +36 -0
  22. vellum_ee/workflows/display/nodes/vellum/base_adornment_node.py +3 -2
  23. vellum_ee/workflows/display/nodes/vellum/retry_node.py +1 -2
  24. vellum_ee/workflows/display/nodes/vellum/tests/test_code_execution_node.py +1 -2
  25. vellum_ee/workflows/display/nodes/vellum/tests/test_error_node.py +1 -2
  26. vellum_ee/workflows/display/nodes/vellum/tests/test_note_node.py +1 -2
  27. vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_node.py +55 -3
  28. vellum_ee/workflows/display/nodes/vellum/tests/test_retry_node.py +1 -2
  29. vellum_ee/workflows/display/nodes/vellum/tests/test_templating_node.py +1 -2
  30. vellum_ee/workflows/display/nodes/vellum/tests/test_try_node.py +1 -2
  31. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +2 -2
  32. vellum_ee/workflows/display/nodes/vellum/try_node.py +1 -2
  33. vellum_ee/workflows/display/nodes/vellum/utils.py +7 -1
  34. vellum_ee/workflows/display/tests/{test_vellum_workflow_display.py → test_base_workflow_display.py} +10 -22
  35. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +4 -6
  36. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +7 -16
  37. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +2 -6
  38. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +1 -2
  39. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +3 -10
  40. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +4 -5
  41. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +1 -4
  42. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +1 -4
  43. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +2 -5
  44. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +7 -5
  45. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_prompt_node_serialization.py +1 -4
  46. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -4
  47. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +1 -2
  48. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +1 -4
  49. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -4
  50. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +7 -5
  51. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +1 -4
  52. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -4
  53. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +1 -4
  54. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +2 -5
  55. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +2 -7
  56. vellum_ee/workflows/display/types.py +5 -4
  57. vellum_ee/workflows/display/utils/exceptions.py +7 -0
  58. vellum_ee/workflows/display/utils/registry.py +37 -0
  59. vellum_ee/workflows/display/utils/vellum.py +2 -1
  60. vellum_ee/workflows/display/workflows/base_workflow_display.py +281 -43
  61. vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +34 -21
  62. vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +58 -20
  63. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +4 -257
  64. vellum_ee/workflows/tests/local_workflow/display/workflow.py +2 -2
  65. {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.40.dist-info}/LICENSE +0 -0
  66. {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.40.dist-info}/WHEEL +0 -0
  67. {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.40.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
18
18
  headers: typing.Dict[str, str] = {
19
19
  "X-Fern-Language": "Python",
20
20
  "X-Fern-SDK-Name": "vellum-ai",
21
- "X-Fern-SDK-Version": "0.14.39",
21
+ "X-Fern-SDK-Version": "0.14.40",
22
22
  }
23
23
  headers["X_API_KEY"] = self.api_key
24
24
  return headers
@@ -4,11 +4,13 @@ from typing import Optional
4
4
 
5
5
  from vellum.client.types.string_vellum_value_request import StringVellumValueRequest
6
6
  from vellum.core.pydantic_utilities import UniversalBaseModel
7
+ from vellum.workflows.constants import undefined
7
8
  from vellum.workflows.descriptors.tests.test_utils import FixtureState
8
9
  from vellum.workflows.inputs.base import BaseInputs
9
10
  from vellum.workflows.nodes import FinalOutputNode
10
11
  from vellum.workflows.nodes.bases.base import BaseNode
11
12
  from vellum.workflows.outputs.base import BaseOutputs
13
+ from vellum.workflows.references.output import OutputReference
12
14
  from vellum.workflows.state.base import BaseState, StateMeta
13
15
 
14
16
 
@@ -259,3 +261,25 @@ def test_resolve_value__for_falsy_values(falsy_value, expected_type):
259
261
 
260
262
  # THEN the output has the correct value
261
263
  assert falsy_output.value == falsy_value
264
+
265
+
266
def test_node_outputs__inherits_instance():
    # GIVEN a node declaring two outputs: `foo` has only an annotation, `bar` has a default value
    class MyNode(BaseNode):
        class Outputs:
            foo: str
            bar = "hello"

    # AND a subclass of that node which does not redeclare its outputs
    class InheritedNode(MyNode):
        pass

    # WHEN both outputs are accessed through the subclass
    inherited_foo = InheritedNode.Outputs.foo
    inherited_bar = InheritedNode.Outputs.bar

    # THEN `foo` is a reference with no instance, while `bar` keeps the inherited default instance
    assert isinstance(inherited_foo, OutputReference)
    assert inherited_foo.instance is undefined
    assert isinstance(inherited_bar, OutputReference)
    assert inherited_bar.instance == "hello"
@@ -4,7 +4,6 @@ from vellum.workflows.context import execution_context, get_parent_context
4
4
  from vellum.workflows.errors.types import WorkflowError, WorkflowErrorCode
5
5
  from vellum.workflows.events.workflow import is_workflow_event
6
6
  from vellum.workflows.exceptions import NodeException
7
- from vellum.workflows.nodes.bases import BaseNode
8
7
  from vellum.workflows.nodes.bases.base_adornment_node import BaseAdornmentNode
9
8
  from vellum.workflows.nodes.utils import create_adornment
10
9
  from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
@@ -24,7 +23,7 @@ class TryNode(BaseAdornmentNode[StateType], Generic[StateType]):
24
23
 
25
24
  on_error_code: Optional[WorkflowErrorCode] = None
26
25
 
27
- class Outputs(BaseNode.Outputs):
26
+ class Outputs(BaseAdornmentNode.Outputs):
28
27
  error: Optional[WorkflowError] = None
29
28
 
30
29
  def run(self) -> Iterator[BaseOutput]:
@@ -0,0 +1,3 @@
1
+ from vellum.workflows.nodes.experimental.tool_calling_node.node import ToolCallingNode
2
+
3
+ __all__ = ["ToolCallingNode"]
@@ -0,0 +1,147 @@
1
+ from collections.abc import Callable
2
+ from typing import Any, ClassVar, Dict, List, Optional, cast
3
+
4
+ from vellum import ChatMessage, FunctionDefinition, PromptBlock
5
+ from vellum.client.types.chat_message_request import ChatMessageRequest
6
+ from vellum.workflows.context import execution_context, get_parent_context
7
+ from vellum.workflows.errors.types import WorkflowErrorCode
8
+ from vellum.workflows.exceptions import NodeException
9
+ from vellum.workflows.graph.graph import Graph
10
+ from vellum.workflows.inputs.base import BaseInputs
11
+ from vellum.workflows.nodes.bases import BaseNode
12
+ from vellum.workflows.nodes.experimental.tool_calling_node.utils import (
13
+ ToolRouterNode,
14
+ create_function_node,
15
+ create_tool_router_node,
16
+ )
17
+ from vellum.workflows.outputs.base import BaseOutputs
18
+ from vellum.workflows.state.base import BaseState
19
+ from vellum.workflows.state.context import WorkflowContext
20
+ from vellum.workflows.types.core import EntityInputsInterface
21
+ from vellum.workflows.workflows.base import BaseWorkflow
22
+
23
+
24
class ToolCallingNode(BaseNode):
    """
    A Node that dynamically invokes the provided functions to the underlying Prompt

    Attributes:
        ml_model: str - The model to use for tool calling (e.g., "gpt-4o-mini")
        blocks: List[PromptBlock] - The prompt blocks to use (same format as InlinePromptNode)
        functions: List[FunctionDefinition] - The functions that can be called
        function_callables: Dict[str, Callable] - Mapping of function name to the callable that implements it;
            every entry in `functions` must have a matching key
        prompt_inputs: Optional[EntityInputsInterface] - Mapping of input variable names to values
        max_tool_calls: int - Maximum number of tool calls (NOTE(review): not read by this class yet, see TODO)
    """

    ml_model: ClassVar[str] = "gpt-4o-mini"
    blocks: ClassVar[List[PromptBlock]] = []
    functions: ClassVar[List[FunctionDefinition]] = []
    function_callables: ClassVar[Dict[str, Callable[..., Any]]] = {}
    prompt_inputs: ClassVar[Optional[EntityInputsInterface]] = None
    # TODO: https://linear.app/vellum/issue/APO-342/support-tool-call-max-retries
    max_tool_calls: ClassVar[int] = 1

    class Outputs(BaseOutputs):
        """
        The outputs of the ToolCallingNode.

        text: The final text response after tool calling
        chat_history: The complete chat history including tool calls
        """

        text: str = ""
        chat_history: List[ChatMessage] = []

    def run(self) -> Outputs:
        """
        Run the tool calling workflow.

        This dynamically builds a graph with router and function nodes, then executes it as a
        subworkflow whose state is seeded with any chat history found in
        ``prompt_inputs["chat_history"]``.

        Returns:
            Outputs populated from the subworkflow's fulfilled outputs.

        Raises:
            ValueError: If a function has no name or has no entry in ``function_callables``.
            NodeException: If the subworkflow pauses or is rejected.
            Exception: If the subworkflow finishes with any other terminal event.
        """
        self._validate_functions()

        initial_chat_history: List[ChatMessage] = []

        # Extract chat history from prompt inputs if available; request-typed messages are
        # converted to ChatMessage so the subworkflow state holds a single message type.
        if self.prompt_inputs and "chat_history" in self.prompt_inputs:
            chat_history_input = self.prompt_inputs["chat_history"]
            if isinstance(chat_history_input, list) and all(
                isinstance(msg, (ChatMessage, ChatMessageRequest)) for msg in chat_history_input
            ):
                initial_chat_history = [
                    msg if isinstance(msg, ChatMessage) else ChatMessage.model_validate(msg.model_dump())
                    for msg in chat_history_input
                ]

        self._build_graph()

        with execution_context(parent_context=get_parent_context()):

            class ToolCallingState(BaseState):
                chat_history: List[ChatMessage] = initial_chat_history

            class ToolCallingWorkflow(BaseWorkflow[BaseInputs, ToolCallingState]):
                graph = self._graph

                class Outputs(BaseWorkflow.Outputs):
                    text: str = ToolRouterNode.Outputs.text
                    chat_history: List[ChatMessage] = ToolCallingState.chat_history

            subworkflow = ToolCallingWorkflow(
                parent_state=self.state,
                context=WorkflowContext.create_from(self._context),
            )

            terminal_event = subworkflow.run()

            if terminal_event.name == "workflow.execution.paused":
                raise NodeException(
                    code=WorkflowErrorCode.INVALID_OUTPUTS,
                    message="Subworkflow unexpectedly paused",
                )
            elif terminal_event.name == "workflow.execution.fulfilled":
                node_outputs = self.Outputs()

                for output_descriptor, output_value in terminal_event.outputs:
                    setattr(node_outputs, output_descriptor.name, output_value)

                return node_outputs
            elif terminal_event.name == "workflow.execution.rejected":
                # Re-raise as a NodeException so the original error code is preserved for
                # callers/adornments that inspect it (a generic Exception drops it).
                raise NodeException(
                    code=terminal_event.error.code,
                    message=f"Workflow execution rejected: {terminal_event.error.message}",
                )

            raise Exception(f"Unexpected workflow event: {terminal_event.name}")

    def _build_graph(self) -> None:
        """Build the router node, one function node per function, and the graph connecting them."""
        self.tool_router_node = create_tool_router_node(
            ml_model=self.ml_model,
            # Pass a copy: `blocks` is a class-level list shared by default, and the router
            # factory may extend the list it receives.
            blocks=list(self.blocks),
            functions=self.functions,
            prompt_inputs=self.prompt_inputs,
        )

        self._function_nodes = {
            function.name: create_function_node(
                function=function,
                function_callable=cast(Callable[..., Any], self.function_callables[function.name]),  # type: ignore
            )
            for function in self.functions
        }

        graph_set = set()

        # Add connections from ports of router to function nodes and back to router
        for function_name, FunctionNodeClass in self._function_nodes.items():
            router_port = getattr(self.tool_router_node.Ports, function_name)  # type: ignore # mypy thinks name is still optional # noqa: E501
            edge_graph = router_port >> FunctionNodeClass >> self.tool_router_node
            graph_set.add(edge_graph)

        default_port = getattr(self.tool_router_node.Ports, "default")
        graph_set.add(default_port)

        self._graph = Graph.from_set(graph_set)

    def _validate_functions(self) -> None:
        """
        Ensure every function can be routed to and executed.

        Raises:
            ValueError: If a function has no name, or no callable is registered for its name.
        """
        for function in self.functions:
            if function.name is None:
                raise ValueError("Function name is required")
            if function.name not in self.function_callables:
                raise ValueError(f"No callable provided for function '{function.name}'")
@@ -0,0 +1,132 @@
1
+ from collections.abc import Callable
2
+ import json
3
+ from typing import Any, Iterator, List, Optional, Type, cast
4
+
5
+ from vellum import ChatMessage, FunctionDefinition, PromptBlock
6
+ from vellum.client.types.function_call_chat_message_content import FunctionCallChatMessageContent
7
+ from vellum.client.types.function_call_chat_message_content_value import FunctionCallChatMessageContentValue
8
+ from vellum.client.types.variable_prompt_block import VariablePromptBlock
9
+ from vellum.workflows.nodes.bases import BaseNode
10
+ from vellum.workflows.nodes.displayable.inline_prompt_node.node import InlinePromptNode
11
+ from vellum.workflows.outputs.base import BaseOutput
12
+ from vellum.workflows.ports.port import Port
13
+ from vellum.workflows.references.lazy import LazyReference
14
+ from vellum.workflows.types.core import EntityInputsInterface
15
+
16
+
17
class FunctionNode(BaseNode):
    """
    Base node for executing one tool function.

    Concrete subclasses are generated by ``create_function_node``, which sets ``function`` and
    supplies the ``run`` implementation.
    """

    # Definition of the function this node executes; assigned on each generated subclass.
    function: FunctionDefinition
21
+
22
+
23
class ToolRouterNode(InlinePromptNode):
    """
    Prompt node that decides whether to call a tool or answer directly.

    Before prompting, the current ``state.chat_history`` is injected into the prompt inputs under
    ``chat_history``. When the prompt produces results, the first result is recorded in the chat
    history:
    - a STRING result is appended as an ASSISTANT message;
    - a FUNCTION_CALL result is appended as a FUNCTION message with the call's name, arguments and id.

    Every output of the underlying prompt is re-yielded unchanged.
    """

    def run(self) -> Iterator[BaseOutput]:
        # prompt_inputs can be None (e.g. ToolCallingNode.prompt_inputs is None by default), and
        # `{**None}` raises TypeError, so fall back to an empty mapping.
        self.prompt_inputs = {**(self.prompt_inputs or {}), "chat_history": self.state.chat_history}  # type: ignore
        generator = super().run()
        for output in generator:
            if output.name == "results" and output.value:
                # `output.value` is truthy here, so the results list has at least one entry.
                values = cast(List[Any], output.value)
                first_result = values[0]
                if first_result.type == "STRING":
                    self.state.chat_history.append(ChatMessage(role="ASSISTANT", text=first_result.value))
                elif first_result.type == "FUNCTION_CALL":
                    function_call = first_result.value
                    if function_call is not None:
                        self.state.chat_history.append(
                            ChatMessage(
                                role="FUNCTION",
                                content=FunctionCallChatMessageContent(
                                    value=FunctionCallChatMessageContentValue(
                                        name=function_call.name,
                                        arguments=function_call.arguments,
                                        id=function_call.id,
                                    ),
                                ),
                            )
                        )
            yield output
49
+
50
+
51
def create_tool_router_node(
    ml_model: str,
    blocks: List[PromptBlock],
    functions: List[FunctionDefinition],
    prompt_inputs: Optional[EntityInputsInterface],
) -> Type[ToolRouterNode]:
    """
    Build a ToolRouterNode subclass configured for the given model, prompt blocks and functions.

    The generated node gets one port per function, taken when the first prompt result is a
    FUNCTION_CALL naming that function, plus a ``default`` else-port. A VARIABLE block bound to
    the ``chat_history`` input is appended to the prompt blocks.

    Args:
        ml_model: Model name for the prompt.
        blocks: Prompt blocks. The list is not modified; a copy with the chat history block
            appended is used by the generated node.
        functions: Function definitions exposed to the prompt; each must have a name.
        prompt_inputs: Prompt input values, or None.

    Returns:
        The generated ToolRouterNode subclass.

    Raises:
        ValueError: If any function has no name.
    """

    def _function_call_condition(function_name: str) -> LazyReference:
        # Bind `function_name` per port. A lambda written directly in the loop below would close
        # over the loop variable itself, so by the time the LazyReference is resolved every port
        # would compare against the *last* function's name.
        return LazyReference(
            lambda: (
                ToolRouterNode.Outputs.results[0]["type"].equals("FUNCTION_CALL")
                & ToolRouterNode.Outputs.results[0]["value"]["name"].equals(function_name)
            )
        )

    Ports = type("Ports", (), {})
    for function in functions:
        if function.name is None:
            # Defensive check: ToolCallingNode validates names before building its graph, but
            # other callers of this factory may not.
            raise ValueError("Function name is required")

        setattr(Ports, function.name, Port.on_if(_function_call_condition(function.name)))

    setattr(Ports, "default", Port.on_else())

    # Add a chat history block to a copy of the blocks, so the caller's list (for example the
    # class-level ToolCallingNode.blocks default) does not grow on every call.
    router_blocks: List[PromptBlock] = [
        *blocks,
        VariablePromptBlock(
            block_type="VARIABLE",
            input_variable="chat_history",
            state=None,
            cache_config=None,
        ),
    ]

    node = type(
        "ToolRouterNode",
        (ToolRouterNode,),
        {
            "ml_model": ml_model,
            "blocks": router_blocks,
            "functions": functions,
            "prompt_inputs": prompt_inputs,
            "Ports": Ports,
            "__module__": __name__,
        },
    )

    return node
99
+
100
+
101
def create_function_node(function: FunctionDefinition, function_callable: Callable[..., Any]) -> Type[FunctionNode]:
    """
    Create a FunctionNode class for a given function.

    The generated node's ``run`` reads the router's ``text`` output and parses it as JSON. It then
    calls ``function_callable`` with the parsed ``"arguments"`` mapping as keyword arguments and
    appends the result to ``state.chat_history`` as a FUNCTION message.

    Args:
        function: The function definition; its name is used in the generated class name.
        function_callable: The Python callable implementing ``function``.

    Returns:
        A new FunctionNode subclass bound to ``function`` and ``function_callable``.
    """

    # Create a class-level wrapper that calls the original function
    def execute_function(self) -> BaseNode.Outputs:
        # The router's text output is parsed as JSON and must contain an "arguments" mapping.
        raw_function_call = self.state.meta.node_outputs.get(ToolRouterNode.Outputs.text)
        function_call = json.loads(raw_function_call)
        arguments = function_call["arguments"]

        # Call the original function directly with the arguments
        result = function_callable(**arguments)

        # Tool callables may return dicts, numbers, etc.; the message text is expected to be a
        # string, so serialize non-string results as JSON (None is passed through unchanged).
        result_text = result if result is None or isinstance(result, str) else json.dumps(result)
        self.state.chat_history.append(ChatMessage(role="FUNCTION", text=result_text))

        return self.Outputs()

    node = type(
        f"FunctionNode_{function.name}",
        (FunctionNode,),
        {
            "function": function,
            "run": execute_function,
            "__module__": __name__,
        },
    )

    return node
@@ -57,10 +57,12 @@ def create_adornment(
57
57
  class Subworkflow(BaseWorkflow):
58
58
  graph = inner_cls
59
59
 
60
- # mypy is wrong here, this works and is defined
61
- class Outputs(inner_cls.Outputs): # type: ignore[name-defined]
60
+ class Outputs(BaseWorkflow.Outputs):
62
61
  pass
63
62
 
63
+ for output_reference in inner_cls.Outputs:
64
+ setattr(Subworkflow.Outputs, output_reference.name, output_reference)
65
+
64
66
  dynamic_module = f"{inner_cls.__module__}.{inner_cls.__name__}.{ADORNMENT_MODULE_NAME}"
65
67
  # This dynamic module allows calls to `type_hints` to work
66
68
  sys.modules[dynamic_module] = ModuleType(dynamic_module)
@@ -147,8 +147,9 @@ class _BaseOutputsMeta(type):
147
147
  instance = vars(cls).get(name, undefined)
148
148
  if instance is undefined:
149
149
  for base in cls.__mro__[1:]:
150
- if hasattr(base, name):
151
- instance = getattr(base, name)
150
+ inherited_output_reference = getattr(base, name, undefined)
151
+ if isinstance(inherited_output_reference, OutputReference):
152
+ instance = inherited_output_reference.instance
152
153
  break
153
154
 
154
155
  types = infer_types(cls, name)
@@ -1,4 +1,6 @@
1
+ from functools import cached_property
1
2
  from queue import Queue
3
+ from uuid import UUID, uuid4
2
4
  from typing import TYPE_CHECKING, Any, Generator, Generic, Optional, Tuple, Type, TypeVar, cast
3
5
 
4
6
  from pydantic import GetCoreSchemaHandler
@@ -31,6 +33,24 @@ class OutputReference(BaseDescriptor[_OutputType], Generic[_OutputType]):
31
33
  def outputs_class(self) -> Type["BaseOutputs"]:
32
34
  return self._outputs_class
33
35
 
36
@cached_property
def id(self) -> UUID:
    """
    Return the identifier of this output reference.

    The id is looked up in ``__output_ids__`` on the node class that owns the outputs class
    (found via its ``_node_class`` attribute), keyed by this output's name. When no owning node is
    attached, ``__output_ids__`` is not a dict, or no UUID is registered for the name, a random
    ``uuid4`` is returned instead. ``cached_property`` keeps whichever value is produced stable for
    the lifetime of this reference.
    """
    node_class = getattr(self._outputs_class, "_node_class", None)
    if not node_class:
        return uuid4()

    output_ids = getattr(node_class, "__output_ids__", {})
    if not isinstance(output_ids, dict):
        return uuid4()

    output_id = output_ids.get(self.name)
    if not isinstance(output_id, UUID):
        return uuid4()

    return output_id
53
+
34
54
  def resolve(self, state: "BaseState") -> _OutputType:
35
55
  node_output = state.meta.node_outputs.get(self, undefined)
36
56
  if isinstance(node_output, Queue):
@@ -7,13 +7,14 @@ import logging
7
7
  from queue import Queue
8
8
  from threading import Lock
9
9
  from uuid import UUID, uuid4
10
- from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, cast
10
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union, cast
11
11
  from typing_extensions import dataclass_transform
12
12
 
13
13
  from pydantic import GetCoreSchemaHandler, ValidationInfo, field_serializer, field_validator
14
14
  from pydantic_core import core_schema
15
15
 
16
16
  from vellum.core.pydantic_utilities import UniversalBaseModel
17
+ from vellum.utils.uuid import is_valid_uuid
17
18
  from vellum.workflows.constants import undefined
18
19
  from vellum.workflows.edges.edge import Edge
19
20
  from vellum.workflows.inputs.base import BaseInputs
@@ -109,18 +110,30 @@ class NodeExecutionCache:
109
110
  self._node_executions_queued = defaultdict(list)
110
111
 
111
112
  @classmethod
112
- def deserialize(cls, raw_data: dict, nodes: Dict[str, Type["BaseNode"]]):
113
+ def deserialize(cls, raw_data: dict, nodes: Dict[Union[str, UUID], Type["BaseNode"]]):
113
114
  cache = cls()
114
115
 
116
+ def get_node_class(node_id: Any) -> Optional[Type["BaseNode"]]:
117
+ if not isinstance(node_id, str):
118
+ return None
119
+
120
+ if is_valid_uuid(node_id):
121
+ return nodes.get(UUID(node_id))
122
+
123
+ return nodes.get(node_id)
124
+
115
125
  dependencies_invoked = raw_data.get("dependencies_invoked")
116
126
  if isinstance(dependencies_invoked, dict):
117
127
  for execution_id, dependencies in dependencies_invoked.items():
118
- cache._dependencies_invoked[UUID(execution_id)] = {nodes[dep] for dep in dependencies if dep in nodes}
128
+ dependency_classes = {get_node_class(dep) for dep in dependencies}
129
+ cache._dependencies_invoked[UUID(execution_id)] = {
130
+ dep_class for dep_class in dependency_classes if dep_class is not None
131
+ }
119
132
 
120
133
  node_executions_fulfilled = raw_data.get("node_executions_fulfilled")
121
134
  if isinstance(node_executions_fulfilled, dict):
122
135
  for node, execution_ids in node_executions_fulfilled.items():
123
- node_class = nodes.get(node)
136
+ node_class = get_node_class(node)
124
137
  if not node_class:
125
138
  continue
126
139
 
@@ -131,7 +144,7 @@ class NodeExecutionCache:
131
144
  node_executions_initiated = raw_data.get("node_executions_initiated")
132
145
  if isinstance(node_executions_initiated, dict):
133
146
  for node, execution_ids in node_executions_initiated.items():
134
- node_class = nodes.get(node)
147
+ node_class = get_node_class(node)
135
148
  if not node_class:
136
149
  continue
137
150
 
@@ -142,7 +155,7 @@ class NodeExecutionCache:
142
155
  node_executions_queued = raw_data.get("node_executions_queued")
143
156
  if isinstance(node_executions_queued, dict):
144
157
  for node, execution_ids in node_executions_queued.items():
145
- node_class = nodes.get(node)
158
+ node_class = get_node_class(node)
146
159
  if not node_class:
147
160
  continue
148
161
 
@@ -193,17 +206,18 @@ class NodeExecutionCache:
193
206
  def dump(self) -> Dict[str, Any]:
194
207
  return {
195
208
  "dependencies_invoked": {
196
- str(execution_id): [str(dep) for dep in dependencies]
209
+ str(execution_id): [str(dep.__id__) for dep in dependencies]
197
210
  for execution_id, dependencies in self._dependencies_invoked.items()
198
211
  },
199
212
  "node_executions_initiated": {
200
- str(node): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
213
+ str(node.__id__): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
201
214
  },
202
215
  "node_executions_fulfilled": {
203
- str(node): execution_ids.dump() for node, execution_ids in self._node_executions_fulfilled.items()
216
+ str(node.__id__): execution_ids.dump()
217
+ for node, execution_ids in self._node_executions_fulfilled.items()
204
218
  },
205
219
  "node_executions_queued": {
206
- str(node): execution_ids for node, execution_ids in self._node_executions_queued.items()
220
+ str(node.__id__): execution_ids for node, execution_ids in self._node_executions_queued.items()
207
221
  },
208
222
  }
209
223
 
@@ -279,7 +293,7 @@ class StateMeta(UniversalBaseModel):
279
293
 
280
294
  @field_serializer("node_outputs")
281
295
  def serialize_node_outputs(self, node_outputs: Dict[OutputReference, Any], _info: Any) -> Dict[str, Any]:
282
- return {str(descriptor): value for descriptor, value in node_outputs.items()}
296
+ return {str(descriptor.id): value for descriptor, value in node_outputs.items()}
283
297
 
284
298
  @field_validator("node_outputs", mode="before")
285
299
  @classmethod
@@ -290,15 +304,22 @@ class StateMeta(UniversalBaseModel):
290
304
  return node_outputs
291
305
 
292
306
  raw_workflow_nodes = workflow_definition.get_nodes()
293
- workflow_node_outputs = {}
307
+ workflow_node_outputs: Dict[Union[str, UUID], OutputReference] = {}
294
308
  for node in raw_workflow_nodes:
295
309
  for output in node.Outputs:
296
310
  workflow_node_outputs[str(output)] = output
311
+ output_id = node.__output_ids__.get(output.name)
312
+ if output_id:
313
+ workflow_node_outputs[output_id] = output
297
314
 
298
315
  node_output_keys = list(node_outputs.keys())
299
316
  deserialized_node_outputs = {}
300
317
  for node_output_key in node_output_keys:
301
- output_reference = workflow_node_outputs.get(node_output_key)
318
+ if is_valid_uuid(node_output_key):
319
+ output_reference = workflow_node_outputs.get(UUID(node_output_key))
320
+ else:
321
+ output_reference = workflow_node_outputs.get(node_output_key)
322
+
302
323
  if not output_reference:
303
324
  continue
304
325
 
@@ -316,10 +337,11 @@ class StateMeta(UniversalBaseModel):
316
337
  if not workflow_definition:
317
338
  return node_execution_cache
318
339
 
319
- nodes_cache: Dict[str, Type["BaseNode"]] = {}
340
+ nodes_cache: Dict[Union[str, UUID], Type["BaseNode"]] = {}
320
341
  raw_workflow_nodes = workflow_definition.get_nodes()
321
342
  for node in raw_workflow_nodes:
322
343
  nodes_cache[str(node)] = node
344
+ nodes_cache[node.__id__] = node
323
345
 
324
346
  return NodeExecutionCache.deserialize(node_execution_cache, nodes_cache)
325
347
 
@@ -47,6 +47,9 @@ class MockNode(BaseNode):
47
47
  baz: str
48
48
 
49
49
 
50
+ MOCK_NODE_OUTPUT_ID = "e4dc3136-0c27-4bda-b3ab-ea355d5219d6"
51
+
52
+
50
53
  def test_state_snapshot__node_attribute_edit():
51
54
  # GIVEN an initial state instance
52
55
  state = MockState(foo="bar")
@@ -144,7 +147,7 @@ def test_state_json_serialization__with_node_output_updates():
144
147
  json_state = json.loads(json.dumps(state, cls=DefaultStateEncoder))
145
148
 
146
149
  # THEN the state is serialized correctly
147
- assert json_state["meta"]["node_outputs"] == {"MockNode.Outputs.baz": "hello"}
150
+ assert json_state["meta"]["node_outputs"] == {MOCK_NODE_OUTPUT_ID: "hello"}
148
151
 
149
152
 
150
153
  def test_state_deepcopy__with_external_input_updates():
@@ -185,7 +188,7 @@ def test_state_json_serialization__with_queue():
185
188
  json_state = json.loads(json.dumps(state, cls=DefaultStateEncoder))
186
189
 
187
190
  # THEN the state is serialized correctly with the queue turned into a list
188
- assert json_state["meta"]["node_outputs"] == {"MockNode.Outputs.baz": ["test1", "test2"]}
191
+ assert json_state["meta"]["node_outputs"] == {MOCK_NODE_OUTPUT_ID: ["test1", "test2"]}
189
192
 
190
193
 
191
194
  def test_state_snapshot__deepcopy_fails__logs_error(mock_deepcopy, mock_logger):
@@ -37,3 +37,14 @@ class Stack(Generic[_T]):
37
37
 
38
38
  def dump(self) -> List[_T]:
39
39
  return [item for item in self._items][::-1]
40
+
41
@classmethod
def from_list(cls, items: List[_T]) -> "Stack[_T]":
    """Create a new stack of this class and ``extend`` it with ``items``."""
    new_stack = cls()
    new_stack.extend(items)
    return new_stack
46
+
47
def __eq__(self, other: object) -> bool:
    """
    Compare two stacks by their underlying items.

    For non-Stack operands this returns ``NotImplemented`` so Python can try the reflected
    comparison on the other operand. If that also declines, Python falls back to identity, so
    ``stack == 1`` is still ``False``.
    """
    if not isinstance(other, Stack):
        return NotImplemented
    return self._items == other._items
@@ -80,6 +80,11 @@ class _BaseWorkflowMeta(type):
80
80
  def __new__(mcs, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any]) -> Any:
81
81
  if "graph" not in dct:
82
82
  dct["graph"] = set()
83
+ for base in bases:
84
+ base_graph = getattr(base, "graph", None)
85
+ if base_graph:
86
+ dct["graph"] = base_graph
87
+ break
83
88
 
84
89
  if "Outputs" in dct:
85
90
  outputs_class = dct["Outputs"]