vellum-ai 0.14.39__py3-none-any.whl → 0.14.41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/reference.md +138 -1
- vellum/client/resources/ad_hoc/client.py +311 -1
- vellum/client/resources/deployments/client.py +2 -2
- vellum/workflows/nodes/bases/tests/test_base_node.py +24 -0
- vellum/workflows/nodes/core/try_node/node.py +1 -2
- vellum/workflows/nodes/experimental/tool_calling_node/__init__.py +3 -0
- vellum/workflows/nodes/experimental/tool_calling_node/node.py +125 -0
- vellum/workflows/nodes/experimental/tool_calling_node/utils.py +128 -0
- vellum/workflows/nodes/utils.py +4 -2
- vellum/workflows/outputs/base.py +3 -2
- vellum/workflows/references/output.py +20 -0
- vellum/workflows/state/base.py +36 -14
- vellum/workflows/state/tests/test_state.py +5 -2
- vellum/workflows/types/stack.py +11 -0
- vellum/workflows/workflows/base.py +5 -0
- vellum/workflows/workflows/tests/test_base_workflow.py +96 -9
- {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.41.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.41.dist-info}/RECORD +84 -80
- vellum_cli/push.py +0 -2
- vellum_ee/workflows/display/base.py +14 -1
- vellum_ee/workflows/display/nodes/base_node_display.py +91 -19
- vellum_ee/workflows/display/nodes/get_node_display_class.py +9 -15
- vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +54 -0
- vellum_ee/workflows/display/nodes/vellum/api_node.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/base_adornment_node.py +4 -4
- vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/conditional_node.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/error_node.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/final_output_node.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/guardrail_node.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/merge_node.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/note_node.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +2 -4
- vellum_ee/workflows/display/nodes/vellum/retry_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/search_node.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/templating_node.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_code_execution_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_error_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_note_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_node.py +55 -3
- vellum_ee/workflows/display/nodes/vellum/tests/test_retry_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_templating_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_try_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +4 -4
- vellum_ee/workflows/display/nodes/vellum/try_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/utils.py +7 -1
- vellum_ee/workflows/display/tests/{test_vellum_workflow_display.py → test_base_workflow_display.py} +10 -22
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +4 -6
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +7 -16
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +2 -6
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +1 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +3 -10
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +4 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +2 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +7 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_prompt_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +1 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +7 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +2 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +2 -7
- vellum_ee/workflows/display/types.py +5 -4
- vellum_ee/workflows/display/utils/exceptions.py +7 -0
- vellum_ee/workflows/display/utils/registry.py +37 -0
- vellum_ee/workflows/display/utils/vellum.py +2 -1
- vellum_ee/workflows/display/workflows/base_workflow_display.py +277 -47
- vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +34 -21
- vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +58 -20
- vellum_ee/workflows/display/workflows/vellum_workflow_display.py +4 -257
- vellum_ee/workflows/tests/local_workflow/display/workflow.py +2 -2
- vellum_ee/workflows/display/nodes/base_node_vellum_display.py +0 -40
- {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.41.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.41.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.41.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,125 @@
|
|
1
|
+
from collections.abc import Callable
|
2
|
+
from typing import Any, ClassVar, List, Optional
|
3
|
+
|
4
|
+
from vellum import ChatMessage, PromptBlock
|
5
|
+
from vellum.workflows.context import execution_context, get_parent_context
|
6
|
+
from vellum.workflows.errors.types import WorkflowErrorCode
|
7
|
+
from vellum.workflows.exceptions import NodeException
|
8
|
+
from vellum.workflows.graph.graph import Graph
|
9
|
+
from vellum.workflows.inputs.base import BaseInputs
|
10
|
+
from vellum.workflows.nodes.bases import BaseNode
|
11
|
+
from vellum.workflows.nodes.experimental.tool_calling_node.utils import (
|
12
|
+
ToolRouterNode,
|
13
|
+
create_function_node,
|
14
|
+
create_tool_router_node,
|
15
|
+
)
|
16
|
+
from vellum.workflows.outputs.base import BaseOutputs
|
17
|
+
from vellum.workflows.state.base import BaseState
|
18
|
+
from vellum.workflows.state.context import WorkflowContext
|
19
|
+
from vellum.workflows.types.core import EntityInputsInterface
|
20
|
+
from vellum.workflows.workflows.base import BaseWorkflow
|
21
|
+
|
22
|
+
|
23
|
+
class ToolCallingNode(BaseNode):
    """
    A Node that dynamically invokes the provided functions to the underlying Prompt

    Each call to ``run`` builds a small subworkflow: a generated router node prompts the model, one
    generated function node per entry in ``functions`` executes the chosen tool, and control returns to
    the router after each function node; the router's ``default`` port has no successor.

    Attributes:
        ml_model: str - The model to use for tool calling (e.g., "gpt-4o-mini")
        blocks: List[PromptBlock] - The prompt blocks to use (same format as InlinePromptNode)
        functions: List[Callable[..., Any]] - The Python callables the model may invoke; each one is
            exposed to the router under its ``__name__``
        prompt_inputs: Optional[EntityInputsInterface] - Mapping of input variable names to values
        max_tool_calls: int - Not yet enforced by ``run`` (see the TODO on the attribute)
    """

    ml_model: ClassVar[str] = "gpt-4o-mini"
    blocks: ClassVar[List[PromptBlock]] = []
    functions: ClassVar[List[Callable[..., Any]]] = []
    prompt_inputs: ClassVar[Optional[EntityInputsInterface]] = None
    # TODO: https://linear.app/vellum/issue/APO-342/support-tool-call-max-retries
    max_tool_calls: ClassVar[int] = 1

    class Outputs(BaseOutputs):
        """
        The outputs of the ToolCallingNode.

        text: The final text response after tool calling
        chat_history: The complete chat history including tool calls
        """

        text: str = ""
        chat_history: List[ChatMessage] = []

    def run(self) -> Outputs:
        """
        Run the tool calling workflow.

        This dynamically builds a graph with router and function nodes,
        then executes the workflow.

        Raises:
            NodeException: if the subworkflow pauses.
            Exception: if the subworkflow is rejected or ends with any other terminal event.
        """

        self._build_graph()

        with execution_context(parent_context=get_parent_context()):

            class ToolCallingState(BaseState):
                chat_history: List[ChatMessage] = []

            class ToolCallingWorkflow(BaseWorkflow[BaseInputs, ToolCallingState]):
                # A class body nested in a function can read that function's locals, so ``self`` is this node.
                graph = self._graph

                class Outputs(BaseWorkflow.Outputs):
                    # NOTE(review): this references the base ToolRouterNode output rather than the generated
                    # subclass stored in self.tool_router_node — confirm the reference resolves across it.
                    text: str = ToolRouterNode.Outputs.text
                    chat_history: List[ChatMessage] = ToolCallingState.chat_history

            subworkflow = ToolCallingWorkflow(
                parent_state=self.state,
                context=WorkflowContext.create_from(self._context),
            )

            terminal_event = subworkflow.run()

            if terminal_event.name == "workflow.execution.paused":
                raise NodeException(
                    code=WorkflowErrorCode.INVALID_OUTPUTS,
                    message="Subworkflow unexpectedly paused",
                )
            elif terminal_event.name == "workflow.execution.fulfilled":
                node_outputs = self.Outputs()

                # Copy every subworkflow output onto this node's Outputs by name.
                for output_descriptor, output_value in terminal_event.outputs:
                    setattr(node_outputs, output_descriptor.name, output_value)

                return node_outputs
            elif terminal_event.name == "workflow.execution.rejected":
                raise Exception(f"Workflow execution rejected: {terminal_event.error}")

            raise Exception(f"Unexpected workflow event: {terminal_event.name}")

    def _build_graph(self) -> None:
        """Generate the router and function node classes and store the wired graph on ``self._graph``.

        Each function gets a router port leading to its function node and then back to the router;
        the router's ``default`` port is added on its own, with no successor.
        """
        self.tool_router_node = create_tool_router_node(
            ml_model=self.ml_model,
            # Pass a copy: the router factory adds a chat_history block to the blocks it receives, and
            # self.blocks is a class-level list shared by every instance (and by subclasses that don't override it).
            blocks=list(self.blocks),
            functions=self.functions,
            prompt_inputs=self.prompt_inputs,
        )

        self._function_nodes = {
            function.__name__: create_function_node(
                function=function,
            )
            for function in self.functions
        }

        graph_set = set()

        # Add connections from ports of router to function nodes and back to router
        for function_name, FunctionNodeClass in self._function_nodes.items():
            router_port = getattr(self.tool_router_node.Ports, function_name)
            edge_graph = router_port >> FunctionNodeClass >> self.tool_router_node
            graph_set.add(edge_graph)

        default_port = getattr(self.tool_router_node.Ports, "default")
        graph_set.add(default_port)

        self._graph = Graph.from_set(graph_set)
|
@@ -0,0 +1,128 @@
|
|
1
|
+
from collections.abc import Callable
|
2
|
+
import json
|
3
|
+
from typing import Any, Iterator, List, Optional, Type, cast
|
4
|
+
|
5
|
+
from vellum import ChatMessage, FunctionDefinition, PromptBlock
|
6
|
+
from vellum.client.types.function_call_chat_message_content import FunctionCallChatMessageContent
|
7
|
+
from vellum.client.types.function_call_chat_message_content_value import FunctionCallChatMessageContentValue
|
8
|
+
from vellum.client.types.variable_prompt_block import VariablePromptBlock
|
9
|
+
from vellum.workflows.nodes.bases import BaseNode
|
10
|
+
from vellum.workflows.nodes.displayable.inline_prompt_node.node import InlinePromptNode
|
11
|
+
from vellum.workflows.outputs.base import BaseOutput
|
12
|
+
from vellum.workflows.ports.port import Port
|
13
|
+
from vellum.workflows.references.lazy import LazyReference
|
14
|
+
from vellum.workflows.types.core import EntityInputsInterface
|
15
|
+
|
16
|
+
|
17
|
+
class FunctionNode(BaseNode):
    """Node that executes a specific function.

    Concrete subclasses are generated per callable by ``create_function_node``, which sets ``function``
    and supplies the ``run`` implementation.
    """

    # NOTE(review): create_function_node assigns a plain Python callable here, not a FunctionDefinition —
    # confirm whether this annotation should be Callable[..., Any].
    function: FunctionDefinition
|
21
|
+
|
22
|
+
|
23
|
+
class ToolRouterNode(InlinePromptNode):
    """InlinePromptNode that feeds the state's chat history to the prompt and records each reply in it.

    Before prompting, the state's ``chat_history`` is merged into ``prompt_inputs``. When the first
    ``results`` value is a STRING or a FUNCTION_CALL, it is appended to ``chat_history`` as an ASSISTANT
    message. Every output is re-yielded unchanged.

    Assumes the workflow state defines ``chat_history`` (ToolCallingNode's subworkflow state does).
    """

    def run(self) -> Iterator[BaseOutput]:
        # prompt_inputs may be None (create_tool_router_node accepts Optional), and ``{**None}`` raises
        # TypeError, so fall back to an empty mapping.
        self.prompt_inputs = {**(self.prompt_inputs or {}), "chat_history": self.state.chat_history}  # type: ignore
        generator = super().run()
        for output in generator:
            # A truthy ``output.value`` is a non-empty results list, so index 0 is safe.
            if output.name == "results" and output.value:
                values = cast(List[Any], output.value)
                first_result = values[0]
                if first_result.type == "STRING":
                    self.state.chat_history.append(ChatMessage(role="ASSISTANT", text=first_result.value))
                elif first_result.type == "FUNCTION_CALL":
                    function_call = first_result.value
                    if function_call is not None:
                        self.state.chat_history.append(
                            ChatMessage(
                                role="ASSISTANT",
                                content=FunctionCallChatMessageContent(
                                    value=FunctionCallChatMessageContentValue(
                                        name=function_call.name,
                                        arguments=function_call.arguments,
                                        id=function_call.id,
                                    ),
                                ),
                            )
                        )
            yield output
|
49
|
+
|
50
|
+
|
51
|
+
def create_tool_router_node(
    ml_model: str,
    blocks: List[PromptBlock],
    functions: List[Callable[..., Any]],
    prompt_inputs: Optional[EntityInputsInterface],
) -> Type[ToolRouterNode]:
    """Generate a ToolRouterNode subclass configured for the given model, blocks and functions.

    The generated ``Ports`` class has one ``on_if`` port per function, named after ``function.__name__``.
    Each port matches when the first result is a FUNCTION_CALL whose name equals that function's name.
    A ``default`` ``on_else`` port is added after the function ports.

    A chat_history variable block is added after ``blocks``. The caller's ``blocks`` list is left unmodified.

    Args:
        ml_model: Model name set on the generated node.
        blocks: Prompt blocks; copied, not mutated.
        functions: Callables to expose as router ports.
        prompt_inputs: Prompt inputs set on the generated node (may be None).

    Returns:
        The generated ToolRouterNode subclass.
    """
    Ports = type("Ports", (), {})
    for function in functions:
        function_name = function.__name__
        # Bind function_name as a default argument: LazyReference calls the lambda later, after the loop
        # has finished, and a plain closure would then see only the last function's name on every port.
        port_condition = LazyReference(
            lambda function_name=function_name: (
                ToolRouterNode.Outputs.results[0]["type"].equals("FUNCTION_CALL")
                & ToolRouterNode.Outputs.results[0]["value"]["name"].equals(function_name)
            )
        )
        port = Port.on_if(port_condition)
        setattr(Ports, function_name, port)

    # NOTE(review): a function literally named "default" would be overwritten by this else port.
    setattr(Ports, "default", Port.on_else())

    # Add a chat history block to a new list rather than appending, so the caller's list (for example a
    # class-level attribute) does not grow on every call.
    blocks = [
        *blocks,
        VariablePromptBlock(
            block_type="VARIABLE",
            input_variable="chat_history",
            state=None,
            cache_config=None,
        ),
    ]

    node = type(
        "ToolRouterNode",
        (ToolRouterNode,),
        {
            "ml_model": ml_model,
            "blocks": blocks,
            "functions": functions,
            "prompt_inputs": prompt_inputs,
            "Ports": Ports,
            "__module__": __name__,
        },
    )

    return node
|
95
|
+
|
96
|
+
|
97
|
+
def create_function_node(function: Callable[..., Any]) -> Type[FunctionNode]:
    """
    Create a FunctionNode class for a given function.

    The generated node's ``run`` works in four steps:
    - Read the router's text output from state.
    - Parse it as JSON and take its ``"arguments"`` mapping.
    - Call ``function(**arguments)``.
    - Append the result to ``chat_history`` as a FUNCTION message.

    Args:
        function: The callable to invoke; its ``__name__`` is used in the generated class name.

    Returns:
        A FunctionNode subclass named ``FunctionNode_<function name>``.
    """

    def execute_function(self) -> BaseNode.Outputs:
        outputs = self.state.meta.node_outputs.get(ToolRouterNode.Outputs.text)
        # The router's text output is presumably the JSON-encoded function call; parse it.
        outputs = json.loads(outputs)
        # Treat a missing or null "arguments" entry as "no arguments" (e.g. zero-parameter functions).
        arguments = outputs.get("arguments") or {}

        # Call the original function directly with the arguments
        result = function(**arguments)

        # Message text must be a string, so serialize other results: JSON when possible, otherwise str().
        if isinstance(result, str):
            result_text = result
        else:
            try:
                result_text = json.dumps(result)
            except (TypeError, ValueError):
                result_text = str(result)

        self.state.chat_history.append(ChatMessage(role="FUNCTION", text=result_text))

        return self.Outputs()

    node = type(
        f"FunctionNode_{function.__name__}",
        (FunctionNode,),
        {
            "function": function,
            "run": execute_function,
            "__module__": __name__,
        },
    )

    return node
|
vellum/workflows/nodes/utils.py
CHANGED
@@ -57,10 +57,12 @@ def create_adornment(
|
|
57
57
|
class Subworkflow(BaseWorkflow):
|
58
58
|
graph = inner_cls
|
59
59
|
|
60
|
-
|
61
|
-
class Outputs(inner_cls.Outputs): # type: ignore[name-defined]
|
60
|
+
class Outputs(BaseWorkflow.Outputs):
|
62
61
|
pass
|
63
62
|
|
63
|
+
for output_reference in inner_cls.Outputs:
|
64
|
+
setattr(Subworkflow.Outputs, output_reference.name, output_reference)
|
65
|
+
|
64
66
|
dynamic_module = f"{inner_cls.__module__}.{inner_cls.__name__}.{ADORNMENT_MODULE_NAME}"
|
65
67
|
# This dynamic module allows calls to `type_hints` to work
|
66
68
|
sys.modules[dynamic_module] = ModuleType(dynamic_module)
|
vellum/workflows/outputs/base.py
CHANGED
@@ -147,8 +147,9 @@ class _BaseOutputsMeta(type):
|
|
147
147
|
instance = vars(cls).get(name, undefined)
|
148
148
|
if instance is undefined:
|
149
149
|
for base in cls.__mro__[1:]:
|
150
|
-
|
151
|
-
|
150
|
+
inherited_output_reference = getattr(base, name, undefined)
|
151
|
+
if isinstance(inherited_output_reference, OutputReference):
|
152
|
+
instance = inherited_output_reference.instance
|
152
153
|
break
|
153
154
|
|
154
155
|
types = infer_types(cls, name)
|
@@ -1,4 +1,6 @@
|
|
1
|
+
from functools import cached_property
|
1
2
|
from queue import Queue
|
3
|
+
from uuid import UUID, uuid4
|
2
4
|
from typing import TYPE_CHECKING, Any, Generator, Generic, Optional, Tuple, Type, TypeVar, cast
|
3
5
|
|
4
6
|
from pydantic import GetCoreSchemaHandler
|
@@ -31,6 +33,24 @@ class OutputReference(BaseDescriptor[_OutputType], Generic[_OutputType]):
|
|
31
33
|
def outputs_class(self) -> Type["BaseOutputs"]:
|
32
34
|
return self._outputs_class
|
33
35
|
|
36
|
+
@cached_property
def id(self) -> UUID:
    """Return the UUID for this output, as declared by its owning node class.

    The lookup goes from this reference's outputs class to its ``_node_class``, then to that node class's
    ``__output_ids__`` mapping, and reads the entry keyed by this output's ``name``. If any step is missing
    or malformed, a random ``uuid4`` is returned instead. Because this is a ``cached_property``, that
    fallback is generated once per reference object.
    """
    node_class = getattr(self._outputs_class, "_node_class", None)
    output_ids = getattr(node_class, "__output_ids__", None) if node_class else None
    output_id = output_ids.get(self.name) if isinstance(output_ids, dict) else None

    return output_id if isinstance(output_id, UUID) else uuid4()
|
53
|
+
|
34
54
|
def resolve(self, state: "BaseState") -> _OutputType:
|
35
55
|
node_output = state.meta.node_outputs.get(self, undefined)
|
36
56
|
if isinstance(node_output, Queue):
|
vellum/workflows/state/base.py
CHANGED
@@ -7,13 +7,14 @@ import logging
|
|
7
7
|
from queue import Queue
|
8
8
|
from threading import Lock
|
9
9
|
from uuid import UUID, uuid4
|
10
|
-
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, cast
|
10
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union, cast
|
11
11
|
from typing_extensions import dataclass_transform
|
12
12
|
|
13
13
|
from pydantic import GetCoreSchemaHandler, ValidationInfo, field_serializer, field_validator
|
14
14
|
from pydantic_core import core_schema
|
15
15
|
|
16
16
|
from vellum.core.pydantic_utilities import UniversalBaseModel
|
17
|
+
from vellum.utils.uuid import is_valid_uuid
|
17
18
|
from vellum.workflows.constants import undefined
|
18
19
|
from vellum.workflows.edges.edge import Edge
|
19
20
|
from vellum.workflows.inputs.base import BaseInputs
|
@@ -109,18 +110,30 @@ class NodeExecutionCache:
|
|
109
110
|
self._node_executions_queued = defaultdict(list)
|
110
111
|
|
111
112
|
@classmethod
|
112
|
-
def deserialize(cls, raw_data: dict, nodes: Dict[str, Type["BaseNode"]]):
|
113
|
+
def deserialize(cls, raw_data: dict, nodes: Dict[Union[str, UUID], Type["BaseNode"]]):
|
113
114
|
cache = cls()
|
114
115
|
|
116
|
+
def get_node_class(node_id: Any) -> Optional[Type["BaseNode"]]:
|
117
|
+
if not isinstance(node_id, str):
|
118
|
+
return None
|
119
|
+
|
120
|
+
if is_valid_uuid(node_id):
|
121
|
+
return nodes.get(UUID(node_id))
|
122
|
+
|
123
|
+
return nodes.get(node_id)
|
124
|
+
|
115
125
|
dependencies_invoked = raw_data.get("dependencies_invoked")
|
116
126
|
if isinstance(dependencies_invoked, dict):
|
117
127
|
for execution_id, dependencies in dependencies_invoked.items():
|
118
|
-
|
128
|
+
dependency_classes = {get_node_class(dep) for dep in dependencies}
|
129
|
+
cache._dependencies_invoked[UUID(execution_id)] = {
|
130
|
+
dep_class for dep_class in dependency_classes if dep_class is not None
|
131
|
+
}
|
119
132
|
|
120
133
|
node_executions_fulfilled = raw_data.get("node_executions_fulfilled")
|
121
134
|
if isinstance(node_executions_fulfilled, dict):
|
122
135
|
for node, execution_ids in node_executions_fulfilled.items():
|
123
|
-
node_class =
|
136
|
+
node_class = get_node_class(node)
|
124
137
|
if not node_class:
|
125
138
|
continue
|
126
139
|
|
@@ -131,7 +144,7 @@ class NodeExecutionCache:
|
|
131
144
|
node_executions_initiated = raw_data.get("node_executions_initiated")
|
132
145
|
if isinstance(node_executions_initiated, dict):
|
133
146
|
for node, execution_ids in node_executions_initiated.items():
|
134
|
-
node_class =
|
147
|
+
node_class = get_node_class(node)
|
135
148
|
if not node_class:
|
136
149
|
continue
|
137
150
|
|
@@ -142,7 +155,7 @@ class NodeExecutionCache:
|
|
142
155
|
node_executions_queued = raw_data.get("node_executions_queued")
|
143
156
|
if isinstance(node_executions_queued, dict):
|
144
157
|
for node, execution_ids in node_executions_queued.items():
|
145
|
-
node_class =
|
158
|
+
node_class = get_node_class(node)
|
146
159
|
if not node_class:
|
147
160
|
continue
|
148
161
|
|
@@ -193,17 +206,18 @@ class NodeExecutionCache:
|
|
193
206
|
def dump(self) -> Dict[str, Any]:
|
194
207
|
return {
|
195
208
|
"dependencies_invoked": {
|
196
|
-
str(execution_id): [str(dep) for dep in dependencies]
|
209
|
+
str(execution_id): [str(dep.__id__) for dep in dependencies]
|
197
210
|
for execution_id, dependencies in self._dependencies_invoked.items()
|
198
211
|
},
|
199
212
|
"node_executions_initiated": {
|
200
|
-
str(node): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
|
213
|
+
str(node.__id__): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
|
201
214
|
},
|
202
215
|
"node_executions_fulfilled": {
|
203
|
-
str(node): execution_ids.dump()
|
216
|
+
str(node.__id__): execution_ids.dump()
|
217
|
+
for node, execution_ids in self._node_executions_fulfilled.items()
|
204
218
|
},
|
205
219
|
"node_executions_queued": {
|
206
|
-
str(node): execution_ids for node, execution_ids in self._node_executions_queued.items()
|
220
|
+
str(node.__id__): execution_ids for node, execution_ids in self._node_executions_queued.items()
|
207
221
|
},
|
208
222
|
}
|
209
223
|
|
@@ -279,7 +293,7 @@ class StateMeta(UniversalBaseModel):
|
|
279
293
|
|
280
294
|
@field_serializer("node_outputs")
|
281
295
|
def serialize_node_outputs(self, node_outputs: Dict[OutputReference, Any], _info: Any) -> Dict[str, Any]:
|
282
|
-
return {str(descriptor): value for descriptor, value in node_outputs.items()}
|
296
|
+
return {str(descriptor.id): value for descriptor, value in node_outputs.items()}
|
283
297
|
|
284
298
|
@field_validator("node_outputs", mode="before")
|
285
299
|
@classmethod
|
@@ -290,15 +304,22 @@ class StateMeta(UniversalBaseModel):
|
|
290
304
|
return node_outputs
|
291
305
|
|
292
306
|
raw_workflow_nodes = workflow_definition.get_nodes()
|
293
|
-
workflow_node_outputs = {}
|
307
|
+
workflow_node_outputs: Dict[Union[str, UUID], OutputReference] = {}
|
294
308
|
for node in raw_workflow_nodes:
|
295
309
|
for output in node.Outputs:
|
296
310
|
workflow_node_outputs[str(output)] = output
|
311
|
+
output_id = node.__output_ids__.get(output.name)
|
312
|
+
if output_id:
|
313
|
+
workflow_node_outputs[output_id] = output
|
297
314
|
|
298
315
|
node_output_keys = list(node_outputs.keys())
|
299
316
|
deserialized_node_outputs = {}
|
300
317
|
for node_output_key in node_output_keys:
|
301
|
-
|
318
|
+
if is_valid_uuid(node_output_key):
|
319
|
+
output_reference = workflow_node_outputs.get(UUID(node_output_key))
|
320
|
+
else:
|
321
|
+
output_reference = workflow_node_outputs.get(node_output_key)
|
322
|
+
|
302
323
|
if not output_reference:
|
303
324
|
continue
|
304
325
|
|
@@ -316,10 +337,11 @@ class StateMeta(UniversalBaseModel):
|
|
316
337
|
if not workflow_definition:
|
317
338
|
return node_execution_cache
|
318
339
|
|
319
|
-
nodes_cache: Dict[str, Type["BaseNode"]] = {}
|
340
|
+
nodes_cache: Dict[Union[str, UUID], Type["BaseNode"]] = {}
|
320
341
|
raw_workflow_nodes = workflow_definition.get_nodes()
|
321
342
|
for node in raw_workflow_nodes:
|
322
343
|
nodes_cache[str(node)] = node
|
344
|
+
nodes_cache[node.__id__] = node
|
323
345
|
|
324
346
|
return NodeExecutionCache.deserialize(node_execution_cache, nodes_cache)
|
325
347
|
|
@@ -47,6 +47,9 @@ class MockNode(BaseNode):
|
|
47
47
|
baz: str
|
48
48
|
|
49
49
|
|
50
|
+
MOCK_NODE_OUTPUT_ID = "e4dc3136-0c27-4bda-b3ab-ea355d5219d6"
|
51
|
+
|
52
|
+
|
50
53
|
def test_state_snapshot__node_attribute_edit():
|
51
54
|
# GIVEN an initial state instance
|
52
55
|
state = MockState(foo="bar")
|
@@ -144,7 +147,7 @@ def test_state_json_serialization__with_node_output_updates():
|
|
144
147
|
json_state = json.loads(json.dumps(state, cls=DefaultStateEncoder))
|
145
148
|
|
146
149
|
# THEN the state is serialized correctly
|
147
|
-
assert json_state["meta"]["node_outputs"] == {
|
150
|
+
assert json_state["meta"]["node_outputs"] == {MOCK_NODE_OUTPUT_ID: "hello"}
|
148
151
|
|
149
152
|
|
150
153
|
def test_state_deepcopy__with_external_input_updates():
|
@@ -185,7 +188,7 @@ def test_state_json_serialization__with_queue():
|
|
185
188
|
json_state = json.loads(json.dumps(state, cls=DefaultStateEncoder))
|
186
189
|
|
187
190
|
# THEN the state is serialized correctly with the queue turned into a list
|
188
|
-
assert json_state["meta"]["node_outputs"] == {
|
191
|
+
assert json_state["meta"]["node_outputs"] == {MOCK_NODE_OUTPUT_ID: ["test1", "test2"]}
|
189
192
|
|
190
193
|
|
191
194
|
def test_state_snapshot__deepcopy_fails__logs_error(mock_deepcopy, mock_logger):
|
vellum/workflows/types/stack.py
CHANGED
@@ -37,3 +37,14 @@ class Stack(Generic[_T]):
|
|
37
37
|
|
38
38
|
def dump(self) -> List[_T]:
|
39
39
|
return [item for item in self._items][::-1]
|
40
|
+
|
41
|
+
@classmethod
def from_list(cls, items: List[_T]) -> "Stack[_T]":
    """Build a new stack by passing ``items`` to ``extend`` on an empty instance.

    NOTE(review): ``dump`` returns the internal items reversed. ``from_list(stack.dump())`` therefore
    reproduces the original stack only if ``extend`` pushes in an order that compensates. Confirm this
    against ``extend`` before relying on round-trips of multi-item stacks.
    """
    stack = cls()
    stack.extend(items)
    return stack
|
46
|
+
|
47
|
+
def __eq__(self, other: object) -> bool:
    """Return whether ``other`` is a Stack holding equal items in the same internal order.

    Non-Stack operands return ``NotImplemented``, so Python can try the other operand's ``__eq__``.
    It falls back to identity otherwise, so ``stack == non_stack`` is still False.

    NOTE(review): defining ``__eq__`` without ``__hash__`` makes Stack instances unhashable — confirm
    no caller stores stacks in sets or uses them as dict keys.
    """
    if not isinstance(other, Stack):
        return NotImplemented
    return self._items == other._items
|
@@ -80,6 +80,11 @@ class _BaseWorkflowMeta(type):
|
|
80
80
|
def __new__(mcs, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any]) -> Any:
|
81
81
|
if "graph" not in dct:
|
82
82
|
dct["graph"] = set()
|
83
|
+
for base in bases:
|
84
|
+
base_graph = getattr(base, "graph", None)
|
85
|
+
if base_graph:
|
86
|
+
dct["graph"] = base_graph
|
87
|
+
break
|
83
88
|
|
84
89
|
if "Outputs" in dct:
|
85
90
|
outputs_class = dct["Outputs"]
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import pytest
|
2
|
-
from uuid import uuid4
|
2
|
+
from uuid import UUID, uuid4
|
3
3
|
|
4
4
|
from vellum.workflows.edges.edge import Edge
|
5
5
|
from vellum.workflows.inputs.base import BaseInputs
|
@@ -7,6 +7,7 @@ from vellum.workflows.nodes.bases.base import BaseNode
|
|
7
7
|
from vellum.workflows.nodes.core.inline_subworkflow_node.node import InlineSubworkflowNode
|
8
8
|
from vellum.workflows.outputs.base import BaseOutputs
|
9
9
|
from vellum.workflows.state.base import BaseState
|
10
|
+
from vellum.workflows.types.stack import Stack
|
10
11
|
from vellum.workflows.workflows.base import BaseWorkflow
|
11
12
|
|
12
13
|
|
@@ -326,7 +327,6 @@ def test_workflow__unsupported_graph_item():
|
|
326
327
|
|
327
328
|
|
328
329
|
def test_base_workflow__deserialize_state():
|
329
|
-
|
330
330
|
# GIVEN a state definition
|
331
331
|
class State(BaseState):
|
332
332
|
bar: str
|
@@ -340,6 +340,14 @@ def test_base_workflow__deserialize_state():
|
|
340
340
|
class Outputs(BaseNode.Outputs):
|
341
341
|
foo: str
|
342
342
|
|
343
|
+
# AND a node id for node A
|
344
|
+
node_a_id = uuid4()
|
345
|
+
NodeA.__id__ = node_a_id
|
346
|
+
|
347
|
+
# AND an output id for NodeA.Outputs.foo
|
348
|
+
node_a_output_id = uuid4()
|
349
|
+
NodeA.__output_ids__["foo"] = node_a_output_id
|
350
|
+
|
343
351
|
# AND a workflow that uses all three
|
344
352
|
class TestWorkflow(BaseWorkflow[Inputs, State]):
|
345
353
|
graph = NodeA
|
@@ -356,20 +364,20 @@ def test_base_workflow__deserialize_state():
|
|
356
364
|
"updated_ts": "2025-04-14T19:22:18.504902",
|
357
365
|
"external_inputs": {},
|
358
366
|
"node_outputs": {
|
359
|
-
|
367
|
+
str(node_a_output_id): "My node A output foo",
|
360
368
|
},
|
361
369
|
"node_execution_cache": {
|
362
370
|
"dependencies_invoked": {
|
363
|
-
last_span_id: [
|
371
|
+
last_span_id: [str(node_a_id)],
|
364
372
|
},
|
365
373
|
"node_executions_initiated": {
|
366
|
-
|
374
|
+
str(node_a_id): [last_span_id],
|
367
375
|
},
|
368
376
|
"node_executions_fulfilled": {
|
369
|
-
|
377
|
+
str(node_a_id): [last_span_id],
|
370
378
|
},
|
371
379
|
"node_executions_queued": {
|
372
|
-
|
380
|
+
str(node_a_id): [],
|
373
381
|
},
|
374
382
|
},
|
375
383
|
"parent": None,
|
@@ -384,8 +392,87 @@ def test_base_workflow__deserialize_state():
|
|
384
392
|
assert isinstance(state.meta.workflow_inputs, Inputs)
|
385
393
|
assert state.meta.workflow_inputs.baz == "My input baz"
|
386
394
|
|
387
|
-
# AND the node execution cache should deserialize
|
388
|
-
assert state.meta.node_execution_cache
|
395
|
+
# AND the node execution cache should deserialize correctly
|
396
|
+
assert state.meta.node_execution_cache._node_executions_initiated == {NodeA: {UUID(last_span_id)}}
|
397
|
+
assert state.meta.node_execution_cache._node_executions_fulfilled == {NodeA: Stack.from_list([UUID(last_span_id)])}
|
398
|
+
assert state.meta.node_execution_cache._node_executions_queued == {NodeA: []}
|
399
|
+
assert state.meta.node_execution_cache._dependencies_invoked == {UUID(last_span_id): {NodeA}}
|
400
|
+
|
401
|
+
|
402
|
+
def test_base_workflow__deserialize_legacy_node_execution_cache():
    """A node execution cache keyed by node path strings (legacy format) still maps back to node classes."""

    # GIVEN a state definition
    class State(BaseState):
        bar: str

    # AND a node
    class NodeA(BaseNode):
        class Outputs(BaseNode.Outputs):
            foo: str

    # AND a workflow that uses both
    class TestWorkflow(BaseWorkflow[BaseInputs, State]):
        graph = NodeA

    # WHEN we deserialize the state that had a legacy node execution cache format,
    # where nodes are identified by their full dotted path instead of their UUID
    last_span_id = str(uuid4())
    node_a_full_name = "vellum.workflows.workflows.tests.test_base_workflow.test_base_workflow__deserialize_legacy_node_execution_cache.<locals>.NodeA"  # noqa: E501
    state = TestWorkflow.deserialize_state(
        {
            "meta": {
                "node_execution_cache": {
                    "dependencies_invoked": {
                        last_span_id: [node_a_full_name],
                    },
                    "node_executions_initiated": {
                        node_a_full_name: [last_span_id],
                    },
                    "node_executions_fulfilled": {
                        node_a_full_name: [last_span_id],
                    },
                    "node_executions_queued": {
                        node_a_full_name: [],
                    },
                },
            },
        },
    )

    # THEN the node execution cache should deserialize correctly
    assert state.meta.node_execution_cache._node_executions_initiated == {NodeA: {UUID(last_span_id)}}
    assert state.meta.node_execution_cache._node_executions_fulfilled == {NodeA: Stack.from_list([UUID(last_span_id)])}
    assert state.meta.node_execution_cache._node_executions_queued == {NodeA: []}
    assert state.meta.node_execution_cache._dependencies_invoked == {UUID(last_span_id): {NodeA}}
|
446
|
+
|
447
|
+
|
448
|
+
def test_base_workflow__deserialize_legacy_node_outputs():
    """Node outputs keyed by their qualified output path (legacy format) still map back to output references."""

    # GIVEN a state definition
    class State(BaseState):
        bar: str

    # AND a node
    class NodeA(BaseNode):
        class Outputs(BaseNode.Outputs):
            foo: str

    # AND a workflow that uses both
    class TestWorkflow(BaseWorkflow[BaseInputs, State]):
        graph = NodeA

    # WHEN we deserialize the state that had a legacy node outputs format (keyed by output path, not output id)
    state = TestWorkflow.deserialize_state(
        {
            "meta": {
                "node_outputs": {
                    "test_base_workflow__deserialize_legacy_node_outputs.<locals>.NodeA.Outputs.foo": "My node A output foo",  # noqa: E501
                },
            },
        },
    )

    # THEN the node outputs should deserialize correctly
    assert state.meta.node_outputs == {NodeA.Outputs.foo: "My node A output foo"}
|
389
476
|
|
390
477
|
|
391
478
|
def test_base_workflow__deserialize_state_with_optional_inputs():
|