vellum-ai 0.14.39__py3-none-any.whl → 0.14.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/workflows/nodes/bases/tests/test_base_node.py +24 -0
- vellum/workflows/nodes/core/try_node/node.py +1 -2
- vellum/workflows/nodes/experimental/tool_calling_node/__init__.py +3 -0
- vellum/workflows/nodes/experimental/tool_calling_node/node.py +147 -0
- vellum/workflows/nodes/experimental/tool_calling_node/utils.py +132 -0
- vellum/workflows/nodes/utils.py +4 -2
- vellum/workflows/outputs/base.py +3 -2
- vellum/workflows/references/output.py +20 -0
- vellum/workflows/state/base.py +36 -14
- vellum/workflows/state/tests/test_state.py +5 -2
- vellum/workflows/types/stack.py +11 -0
- vellum/workflows/workflows/base.py +5 -0
- vellum/workflows/workflows/tests/test_base_workflow.py +96 -9
- {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.40.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.40.dist-info}/RECORD +67 -62
- vellum_cli/push.py +0 -2
- vellum_ee/workflows/display/base.py +14 -1
- vellum_ee/workflows/display/nodes/base_node_display.py +56 -14
- vellum_ee/workflows/display/nodes/get_node_display_class.py +9 -15
- vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +36 -0
- vellum_ee/workflows/display/nodes/vellum/base_adornment_node.py +3 -2
- vellum_ee/workflows/display/nodes/vellum/retry_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_code_execution_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_error_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_note_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_node.py +55 -3
- vellum_ee/workflows/display/nodes/vellum/tests/test_retry_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_templating_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_try_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/try_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/utils.py +7 -1
- vellum_ee/workflows/display/tests/{test_vellum_workflow_display.py → test_base_workflow_display.py} +10 -22
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +4 -6
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +7 -16
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +2 -6
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +1 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +3 -10
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +4 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +2 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +7 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_prompt_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +1 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +7 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +2 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +2 -7
- vellum_ee/workflows/display/types.py +5 -4
- vellum_ee/workflows/display/utils/exceptions.py +7 -0
- vellum_ee/workflows/display/utils/registry.py +37 -0
- vellum_ee/workflows/display/utils/vellum.py +2 -1
- vellum_ee/workflows/display/workflows/base_workflow_display.py +281 -43
- vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +34 -21
- vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +58 -20
- vellum_ee/workflows/display/workflows/vellum_workflow_display.py +4 -257
- vellum_ee/workflows/tests/local_workflow/display/workflow.py +2 -2
- {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.40.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.40.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.40.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
|
|
18
18
|
headers: typing.Dict[str, str] = {
|
19
19
|
"X-Fern-Language": "Python",
|
20
20
|
"X-Fern-SDK-Name": "vellum-ai",
|
21
|
-
"X-Fern-SDK-Version": "0.14.
|
21
|
+
"X-Fern-SDK-Version": "0.14.40",
|
22
22
|
}
|
23
23
|
headers["X_API_KEY"] = self.api_key
|
24
24
|
return headers
|
@@ -4,11 +4,13 @@ from typing import Optional
|
|
4
4
|
|
5
5
|
from vellum.client.types.string_vellum_value_request import StringVellumValueRequest
|
6
6
|
from vellum.core.pydantic_utilities import UniversalBaseModel
|
7
|
+
from vellum.workflows.constants import undefined
|
7
8
|
from vellum.workflows.descriptors.tests.test_utils import FixtureState
|
8
9
|
from vellum.workflows.inputs.base import BaseInputs
|
9
10
|
from vellum.workflows.nodes import FinalOutputNode
|
10
11
|
from vellum.workflows.nodes.bases.base import BaseNode
|
11
12
|
from vellum.workflows.outputs.base import BaseOutputs
|
13
|
+
from vellum.workflows.references.output import OutputReference
|
12
14
|
from vellum.workflows.state.base import BaseState, StateMeta
|
13
15
|
|
14
16
|
|
@@ -259,3 +261,25 @@ def test_resolve_value__for_falsy_values(falsy_value, expected_type):
|
|
259
261
|
|
260
262
|
# THEN the output has the correct value
|
261
263
|
assert falsy_output.value == falsy_value
|
264
|
+
|
265
|
+
|
266
|
+
def test_node_outputs__inherits_instance():
    """Output references looked up on a subclassed node keep the parent's declared instance."""

    # GIVEN a node declaring `foo` without a default and `bar` with the default "hello"
    class MyNode(BaseNode):
        class Outputs:
            foo: str
            bar = "hello"

    # AND a subclass of that node that declares nothing of its own
    class InheritedNode(MyNode):
        pass

    # WHEN both outputs are referenced through the subclass's Outputs
    inherited_outputs = InheritedNode.Outputs
    foo_reference = inherited_outputs.foo
    bar_reference = inherited_outputs.bar

    # THEN the reference without a default stays undefined
    assert isinstance(foo_reference, OutputReference)
    assert foo_reference.instance is undefined

    # AND the reference with a default inherits the parent's value
    assert isinstance(bar_reference, OutputReference)
    assert bar_reference.instance == "hello"
|
@@ -4,7 +4,6 @@ from vellum.workflows.context import execution_context, get_parent_context
|
|
4
4
|
from vellum.workflows.errors.types import WorkflowError, WorkflowErrorCode
|
5
5
|
from vellum.workflows.events.workflow import is_workflow_event
|
6
6
|
from vellum.workflows.exceptions import NodeException
|
7
|
-
from vellum.workflows.nodes.bases import BaseNode
|
8
7
|
from vellum.workflows.nodes.bases.base_adornment_node import BaseAdornmentNode
|
9
8
|
from vellum.workflows.nodes.utils import create_adornment
|
10
9
|
from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
|
@@ -24,7 +23,7 @@ class TryNode(BaseAdornmentNode[StateType], Generic[StateType]):
|
|
24
23
|
|
25
24
|
on_error_code: Optional[WorkflowErrorCode] = None
|
26
25
|
|
27
|
-
class Outputs(
|
26
|
+
class Outputs(BaseAdornmentNode.Outputs):
|
28
27
|
error: Optional[WorkflowError] = None
|
29
28
|
|
30
29
|
def run(self) -> Iterator[BaseOutput]:
|
@@ -0,0 +1,147 @@
|
|
1
|
+
from collections.abc import Callable
|
2
|
+
from typing import Any, ClassVar, Dict, List, Optional, cast
|
3
|
+
|
4
|
+
from vellum import ChatMessage, FunctionDefinition, PromptBlock
|
5
|
+
from vellum.client.types.chat_message_request import ChatMessageRequest
|
6
|
+
from vellum.workflows.context import execution_context, get_parent_context
|
7
|
+
from vellum.workflows.errors.types import WorkflowErrorCode
|
8
|
+
from vellum.workflows.exceptions import NodeException
|
9
|
+
from vellum.workflows.graph.graph import Graph
|
10
|
+
from vellum.workflows.inputs.base import BaseInputs
|
11
|
+
from vellum.workflows.nodes.bases import BaseNode
|
12
|
+
from vellum.workflows.nodes.experimental.tool_calling_node.utils import (
|
13
|
+
ToolRouterNode,
|
14
|
+
create_function_node,
|
15
|
+
create_tool_router_node,
|
16
|
+
)
|
17
|
+
from vellum.workflows.outputs.base import BaseOutputs
|
18
|
+
from vellum.workflows.state.base import BaseState
|
19
|
+
from vellum.workflows.state.context import WorkflowContext
|
20
|
+
from vellum.workflows.types.core import EntityInputsInterface
|
21
|
+
from vellum.workflows.workflows.base import BaseWorkflow
|
22
|
+
|
23
|
+
|
24
|
+
class ToolCallingNode(BaseNode):
    """
    A Node that dynamically invokes the provided functions to the underlying Prompt

    Attributes:
        ml_model: str - The model to use for tool calling (e.g., "gpt-4o-mini")
        blocks: List[PromptBlock] - The prompt blocks to use (same format as InlinePromptNode)
        functions: List[FunctionDefinition] - The functions that can be called
        function_callables: Dict[str, Callable] - The callables implementing `functions`, keyed by function name
        prompt_inputs: Optional[EntityInputsInterface] - Mapping of input variable names to values
    """

    ml_model: ClassVar[str] = "gpt-4o-mini"
    blocks: ClassVar[List[PromptBlock]] = []
    functions: ClassVar[List[FunctionDefinition]] = []
    function_callables: ClassVar[Dict[str, Callable[..., Any]]] = {}
    prompt_inputs: ClassVar[Optional[EntityInputsInterface]] = None
    # TODO: https://linear.app/vellum/issue/APO-342/support-tool-call-max-retries
    max_tool_calls: ClassVar[int] = 1

    class Outputs(BaseOutputs):
        """
        The outputs of the ToolCallingNode.

        text: The final text response after tool calling
        chat_history: The complete chat history including tool calls
        """

        text: str = ""
        # NOTE(review): mutable default; confirm BaseOutputs does not share this list between instances.
        chat_history: List[ChatMessage] = []

    def run(self) -> Outputs:
        """
        Run the tool calling workflow.

        This dynamically builds a graph with router and function nodes, then executes it as a
        subworkflow whose state starts from any chat history found in ``prompt_inputs``.

        Raises:
            ValueError: if a function has no name or no matching entry in ``function_callables``.
            NodeException: if the subworkflow pauses.
            Exception: if the subworkflow is rejected or ends with an unexpected event.
        """
        self._validate_functions()

        initial_chat_history = self._extract_initial_chat_history()

        self._build_graph()

        with execution_context(parent_context=get_parent_context()):

            class ToolCallingState(BaseState):
                chat_history: List[ChatMessage] = initial_chat_history

            class ToolCallingWorkflow(BaseWorkflow[BaseInputs, ToolCallingState]):
                graph = self._graph

                class Outputs(BaseWorkflow.Outputs):
                    # Reference the router class generated for this run (the node actually in the graph)
                    # rather than the base ToolRouterNode, whose outputs that graph never produces
                    # (assumes output references are keyed per node class — confirm).
                    text: str = self.tool_router_node.Outputs.text
                    chat_history: List[ChatMessage] = ToolCallingState.chat_history

            subworkflow = ToolCallingWorkflow(
                parent_state=self.state,
                context=WorkflowContext.create_from(self._context),
            )

            terminal_event = subworkflow.run()

        if terminal_event.name == "workflow.execution.paused":
            raise NodeException(
                code=WorkflowErrorCode.INVALID_OUTPUTS,
                message="Subworkflow unexpectedly paused",
            )
        elif terminal_event.name == "workflow.execution.fulfilled":
            node_outputs = self.Outputs()

            for output_descriptor, output_value in terminal_event.outputs:
                setattr(node_outputs, output_descriptor.name, output_value)

            return node_outputs
        elif terminal_event.name == "workflow.execution.rejected":
            raise Exception(f"Workflow execution rejected: {terminal_event.error}")

        raise Exception(f"Unexpected workflow event: {terminal_event.name}")

    def _extract_initial_chat_history(self) -> List[ChatMessage]:
        """
        Return ``prompt_inputs["chat_history"]`` as a list of ChatMessages.

        Returns an empty list when there is no such input, or when it is not a list made up
        entirely of ChatMessage / ChatMessageRequest items.
        """
        if not self.prompt_inputs or "chat_history" not in self.prompt_inputs:
            return []

        chat_history_input = self.prompt_inputs["chat_history"]
        if not isinstance(chat_history_input, list) or not all(
            isinstance(msg, (ChatMessage, ChatMessageRequest)) for msg in chat_history_input
        ):
            return []

        # ChatMessageRequest entries are converted by dumping and re-validating them as ChatMessage.
        return [
            msg if isinstance(msg, ChatMessage) else ChatMessage.model_validate(msg.model_dump())
            for msg in chat_history_input
        ]

    def _build_graph(self) -> None:
        """Create the router and one function node per function, then wire them into ``self._graph``."""
        self.tool_router_node = create_tool_router_node(
            ml_model=self.ml_model,
            # Pass a copy so the class-level ``blocks`` list, shared by every run of this node class,
            # is never mutated while the router is being built.
            blocks=list(self.blocks),
            functions=self.functions,
            prompt_inputs=self.prompt_inputs,
        )

        self._function_nodes = {
            function.name: create_function_node(
                function=function,
                function_callable=cast(Callable[..., Any], self.function_callables[function.name]),  # type: ignore
            )
            for function in self.functions
        }

        graph_set = set()

        # Add connections from ports of router to function nodes and back to router
        for function_name, FunctionNodeClass in self._function_nodes.items():
            router_port = getattr(self.tool_router_node.Ports, function_name)  # type: ignore # mypy thinks name is still optional # noqa: E501
            edge_graph = router_port >> FunctionNodeClass >> self.tool_router_node
            graph_set.add(edge_graph)

        # The default port is added on its own, with no downstream node.
        default_port = getattr(self.tool_router_node.Ports, "default")
        graph_set.add(default_port)

        self._graph = Graph.from_set(graph_set)

    def _validate_functions(self) -> None:
        """
        Check that every function can be routed to and executed.

        Raises:
            ValueError: if a function has no name, or no callable is registered under its name.
        """
        for function in self.functions:
            if function.name is None:
                raise ValueError("Function name is required")
            if function.name not in self.function_callables:
                raise ValueError(f"No callable provided for function '{function.name}'")
|
@@ -0,0 +1,132 @@
|
|
1
|
+
from collections.abc import Callable
|
2
|
+
import json
|
3
|
+
from typing import Any, Iterator, List, Optional, Type, cast
|
4
|
+
|
5
|
+
from vellum import ChatMessage, FunctionDefinition, PromptBlock
|
6
|
+
from vellum.client.types.function_call_chat_message_content import FunctionCallChatMessageContent
|
7
|
+
from vellum.client.types.function_call_chat_message_content_value import FunctionCallChatMessageContentValue
|
8
|
+
from vellum.client.types.variable_prompt_block import VariablePromptBlock
|
9
|
+
from vellum.workflows.nodes.bases import BaseNode
|
10
|
+
from vellum.workflows.nodes.displayable.inline_prompt_node.node import InlinePromptNode
|
11
|
+
from vellum.workflows.outputs.base import BaseOutput
|
12
|
+
from vellum.workflows.ports.port import Port
|
13
|
+
from vellum.workflows.references.lazy import LazyReference
|
14
|
+
from vellum.workflows.types.core import EntityInputsInterface
|
15
|
+
|
16
|
+
|
17
|
+
class FunctionNode(BaseNode):
    """
    Base class for generated nodes that each run a single tool function.

    Concrete subclasses set ``function`` to the definition they execute.
    """

    # The FunctionDefinition executed by this node; assigned on each generated subclass.
    function: FunctionDefinition
|
21
|
+
|
22
|
+
|
23
|
+
class ToolRouterNode(InlinePromptNode):
    """
    InlinePromptNode that feeds the running chat history to the prompt and records each reply in it.

    The first ``results`` entry is appended to ``state.chat_history``:
    - a STRING result becomes an ASSISTANT text message;
    - a FUNCTION_CALL result becomes a FUNCTION message with function-call content.

    Requires ``self.state`` to expose a mutable ``chat_history`` list.
    """

    def run(self) -> Iterator[BaseOutput]:
        # ``prompt_inputs`` may be None, and ``{**None}`` raises TypeError.
        # Treat a missing mapping as empty before adding the chat_history input.
        self.prompt_inputs = {**(self.prompt_inputs or {}), "chat_history": self.state.chat_history}  # type: ignore

        for output in super().run():
            if output.name == "results" and output.value:
                first_result = cast(List[Any], output.value)[0]

                if first_result.type == "STRING":
                    self.state.chat_history.append(ChatMessage(role="ASSISTANT", text=first_result.value))
                elif first_result.type == "FUNCTION_CALL":
                    function_call = first_result.value
                    if function_call is not None:
                        # NOTE(review): the model's own function call is recorded with role FUNCTION, the same
                        # role used for function results; confirm this (rather than ASSISTANT) is intended.
                        self.state.chat_history.append(
                            ChatMessage(
                                role="FUNCTION",
                                content=FunctionCallChatMessageContent(
                                    value=FunctionCallChatMessageContentValue(
                                        name=function_call.name,
                                        arguments=function_call.arguments,
                                        id=function_call.id,
                                    ),
                                ),
                            )
                        )

            yield output
|
49
|
+
|
50
|
+
|
51
|
+
def create_tool_router_node(
    ml_model: str,
    blocks: List[PromptBlock],
    functions: List[FunctionDefinition],
    prompt_inputs: Optional[EntityInputsInterface],
) -> Type[ToolRouterNode]:
    """
    Build a ToolRouterNode subclass with one port per function plus a ``default`` else-port.

    Each function port fires when the first prompt result is a FUNCTION_CALL whose name matches that
    function. A VARIABLE block bound to the ``chat_history`` input is appended to the prompt blocks.

    Args:
        ml_model: Model name for the prompt.
        blocks: Prompt blocks; this list is not modified.
        functions: Functions the model may call; each must have a name.
        prompt_inputs: Prompt inputs forwarded to the node.

    Returns:
        The generated ToolRouterNode subclass.

    Raises:
        ValueError: if any function has no name, since ports are keyed by function name.
    """
    Ports = type("Ports", (), {})

    def _port_condition(fn_name: str) -> LazyReference:
        # ``fn_name`` is bound per call. A lambda closing over the loop variable directly would
        # see only the *last* function name once resolved after the loop.
        # ``node`` is also resolved lazily, because the router subclass is created further below.
        return LazyReference(
            lambda: (
                node.Outputs.results[0]["type"].equals("FUNCTION_CALL")
                & node.Outputs.results[0]["value"]["name"].equals(fn_name)
            )
        )

    for function in functions:
        if function.name is None:
            # An unnamed function cannot get a port, so it cannot be routed to.
            raise ValueError("Function name is required")

        setattr(Ports, function.name, Port.on_if(_port_condition(function.name)))

    setattr(Ports, "default", Port.on_else())

    # Build a new list instead of appending, so the caller's ``blocks`` is left untouched.
    router_blocks = [
        *blocks,
        VariablePromptBlock(
            block_type="VARIABLE",
            input_variable="chat_history",
            state=None,
            cache_config=None,
        ),
    ]

    node = cast(
        Type[ToolRouterNode],
        type(
            "ToolRouterNode",
            (ToolRouterNode,),
            {
                "ml_model": ml_model,
                "blocks": router_blocks,
                "functions": functions,
                "prompt_inputs": prompt_inputs,
                "Ports": Ports,
                "__module__": __name__,
            },
        ),
    )

    return node
|
99
|
+
|
100
|
+
|
101
|
+
def create_function_node(
    function: FunctionDefinition,
    function_callable: Callable[..., Any],
    tool_router_node: Optional[Type[ToolRouterNode]] = None,
) -> Type[FunctionNode]:
    """
    Create a FunctionNode class for a given function.

    The generated node:
    - reads the router's ``text`` output from state and parses it as JSON;
    - calls ``function_callable(**arguments)``;
    - appends the result to ``state.chat_history`` as a FUNCTION message.

    Args:
        function: The definition the node executes; also used to name the generated class.
        function_callable: The implementation, called with the model-provided keyword arguments.
        tool_router_node: Router class whose ``text`` output holds the function call. Defaults to the
            base ``ToolRouterNode``. NOTE(review): pass the class returned by
            ``create_tool_router_node`` if output references are keyed per node class — confirm.

    Returns:
        A new ``FunctionNode`` subclass named ``FunctionNode_<function name>``.
    """
    router_node_class = tool_router_node or ToolRouterNode

    def execute_function(self) -> BaseNode.Outputs:
        raw_function_call = self.state.meta.node_outputs.get(router_node_class.Outputs.text)
        arguments = json.loads(raw_function_call)["arguments"]

        result = function_callable(**arguments)

        # ChatMessage.text is presumably a string field, so JSON-encode structured (non-str)
        # results instead of handing e.g. a dict or int to the message model. None passes through.
        if result is not None and not isinstance(result, str):
            result = json.dumps(result)

        self.state.chat_history.append(ChatMessage(role="FUNCTION", text=result))

        return self.Outputs()

    node = type(
        f"FunctionNode_{function.name}",
        (FunctionNode,),
        {
            "function": function,
            "run": execute_function,
            "__module__": __name__,
        },
    )

    return node
|
vellum/workflows/nodes/utils.py
CHANGED
@@ -57,10 +57,12 @@ def create_adornment(
|
|
57
57
|
class Subworkflow(BaseWorkflow):
|
58
58
|
graph = inner_cls
|
59
59
|
|
60
|
-
|
61
|
-
class Outputs(inner_cls.Outputs): # type: ignore[name-defined]
|
60
|
+
class Outputs(BaseWorkflow.Outputs):
|
62
61
|
pass
|
63
62
|
|
63
|
+
for output_reference in inner_cls.Outputs:
|
64
|
+
setattr(Subworkflow.Outputs, output_reference.name, output_reference)
|
65
|
+
|
64
66
|
dynamic_module = f"{inner_cls.__module__}.{inner_cls.__name__}.{ADORNMENT_MODULE_NAME}"
|
65
67
|
# This dynamic module allows calls to `type_hints` to work
|
66
68
|
sys.modules[dynamic_module] = ModuleType(dynamic_module)
|
vellum/workflows/outputs/base.py
CHANGED
@@ -147,8 +147,9 @@ class _BaseOutputsMeta(type):
|
|
147
147
|
instance = vars(cls).get(name, undefined)
|
148
148
|
if instance is undefined:
|
149
149
|
for base in cls.__mro__[1:]:
|
150
|
-
|
151
|
-
|
150
|
+
inherited_output_reference = getattr(base, name, undefined)
|
151
|
+
if isinstance(inherited_output_reference, OutputReference):
|
152
|
+
instance = inherited_output_reference.instance
|
152
153
|
break
|
153
154
|
|
154
155
|
types = infer_types(cls, name)
|
@@ -1,4 +1,6 @@
|
|
1
|
+
from functools import cached_property
|
1
2
|
from queue import Queue
|
3
|
+
from uuid import UUID, uuid4
|
2
4
|
from typing import TYPE_CHECKING, Any, Generator, Generic, Optional, Tuple, Type, TypeVar, cast
|
3
5
|
|
4
6
|
from pydantic import GetCoreSchemaHandler
|
@@ -31,6 +33,24 @@ class OutputReference(BaseDescriptor[_OutputType], Generic[_OutputType]):
|
|
31
33
|
def outputs_class(self) -> Type["BaseOutputs"]:
|
32
34
|
return self._outputs_class
|
33
35
|
|
36
|
+
@cached_property
def id(self) -> UUID:
    """
    Identifier for this output, taken from the owning node's ``__output_ids__`` mapping.

    A fresh random ``uuid4()`` is returned instead when:
    - the outputs class has no ``_node_class``;
    - that node's ``__output_ids__`` is not a dict;
    - the dict holds no UUID for this output's name.

    The value is cached per reference instance.

    NOTE(review): the random fallback differs between reference instances, so ids produced that way
    cannot be matched back to an output when state is deserialized — confirm this is acceptable.
    """
    node_class = getattr(self._outputs_class, "_node_class", None)
    if not node_class:
        return uuid4()

    output_ids = getattr(node_class, "__output_ids__", {})
    if not isinstance(output_ids, dict):
        return uuid4()

    output_id = output_ids.get(self.name)
    return output_id if isinstance(output_id, UUID) else uuid4()
|
53
|
+
|
34
54
|
def resolve(self, state: "BaseState") -> _OutputType:
|
35
55
|
node_output = state.meta.node_outputs.get(self, undefined)
|
36
56
|
if isinstance(node_output, Queue):
|
vellum/workflows/state/base.py
CHANGED
@@ -7,13 +7,14 @@ import logging
|
|
7
7
|
from queue import Queue
|
8
8
|
from threading import Lock
|
9
9
|
from uuid import UUID, uuid4
|
10
|
-
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, cast
|
10
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union, cast
|
11
11
|
from typing_extensions import dataclass_transform
|
12
12
|
|
13
13
|
from pydantic import GetCoreSchemaHandler, ValidationInfo, field_serializer, field_validator
|
14
14
|
from pydantic_core import core_schema
|
15
15
|
|
16
16
|
from vellum.core.pydantic_utilities import UniversalBaseModel
|
17
|
+
from vellum.utils.uuid import is_valid_uuid
|
17
18
|
from vellum.workflows.constants import undefined
|
18
19
|
from vellum.workflows.edges.edge import Edge
|
19
20
|
from vellum.workflows.inputs.base import BaseInputs
|
@@ -109,18 +110,30 @@ class NodeExecutionCache:
|
|
109
110
|
self._node_executions_queued = defaultdict(list)
|
110
111
|
|
111
112
|
@classmethod
|
112
|
-
def deserialize(cls, raw_data: dict, nodes: Dict[str, Type["BaseNode"]]):
|
113
|
+
def deserialize(cls, raw_data: dict, nodes: Dict[Union[str, UUID], Type["BaseNode"]]):
|
113
114
|
cache = cls()
|
114
115
|
|
116
|
+
def get_node_class(node_id: Any) -> Optional[Type["BaseNode"]]:
|
117
|
+
if not isinstance(node_id, str):
|
118
|
+
return None
|
119
|
+
|
120
|
+
if is_valid_uuid(node_id):
|
121
|
+
return nodes.get(UUID(node_id))
|
122
|
+
|
123
|
+
return nodes.get(node_id)
|
124
|
+
|
115
125
|
dependencies_invoked = raw_data.get("dependencies_invoked")
|
116
126
|
if isinstance(dependencies_invoked, dict):
|
117
127
|
for execution_id, dependencies in dependencies_invoked.items():
|
118
|
-
|
128
|
+
dependency_classes = {get_node_class(dep) for dep in dependencies}
|
129
|
+
cache._dependencies_invoked[UUID(execution_id)] = {
|
130
|
+
dep_class for dep_class in dependency_classes if dep_class is not None
|
131
|
+
}
|
119
132
|
|
120
133
|
node_executions_fulfilled = raw_data.get("node_executions_fulfilled")
|
121
134
|
if isinstance(node_executions_fulfilled, dict):
|
122
135
|
for node, execution_ids in node_executions_fulfilled.items():
|
123
|
-
node_class =
|
136
|
+
node_class = get_node_class(node)
|
124
137
|
if not node_class:
|
125
138
|
continue
|
126
139
|
|
@@ -131,7 +144,7 @@ class NodeExecutionCache:
|
|
131
144
|
node_executions_initiated = raw_data.get("node_executions_initiated")
|
132
145
|
if isinstance(node_executions_initiated, dict):
|
133
146
|
for node, execution_ids in node_executions_initiated.items():
|
134
|
-
node_class =
|
147
|
+
node_class = get_node_class(node)
|
135
148
|
if not node_class:
|
136
149
|
continue
|
137
150
|
|
@@ -142,7 +155,7 @@ class NodeExecutionCache:
|
|
142
155
|
node_executions_queued = raw_data.get("node_executions_queued")
|
143
156
|
if isinstance(node_executions_queued, dict):
|
144
157
|
for node, execution_ids in node_executions_queued.items():
|
145
|
-
node_class =
|
158
|
+
node_class = get_node_class(node)
|
146
159
|
if not node_class:
|
147
160
|
continue
|
148
161
|
|
@@ -193,17 +206,18 @@ class NodeExecutionCache:
|
|
193
206
|
def dump(self) -> Dict[str, Any]:
|
194
207
|
return {
|
195
208
|
"dependencies_invoked": {
|
196
|
-
str(execution_id): [str(dep) for dep in dependencies]
|
209
|
+
str(execution_id): [str(dep.__id__) for dep in dependencies]
|
197
210
|
for execution_id, dependencies in self._dependencies_invoked.items()
|
198
211
|
},
|
199
212
|
"node_executions_initiated": {
|
200
|
-
str(node): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
|
213
|
+
str(node.__id__): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
|
201
214
|
},
|
202
215
|
"node_executions_fulfilled": {
|
203
|
-
str(node): execution_ids.dump()
|
216
|
+
str(node.__id__): execution_ids.dump()
|
217
|
+
for node, execution_ids in self._node_executions_fulfilled.items()
|
204
218
|
},
|
205
219
|
"node_executions_queued": {
|
206
|
-
str(node): execution_ids for node, execution_ids in self._node_executions_queued.items()
|
220
|
+
str(node.__id__): execution_ids for node, execution_ids in self._node_executions_queued.items()
|
207
221
|
},
|
208
222
|
}
|
209
223
|
|
@@ -279,7 +293,7 @@ class StateMeta(UniversalBaseModel):
|
|
279
293
|
|
280
294
|
@field_serializer("node_outputs")
|
281
295
|
def serialize_node_outputs(self, node_outputs: Dict[OutputReference, Any], _info: Any) -> Dict[str, Any]:
|
282
|
-
return {str(descriptor): value for descriptor, value in node_outputs.items()}
|
296
|
+
return {str(descriptor.id): value for descriptor, value in node_outputs.items()}
|
283
297
|
|
284
298
|
@field_validator("node_outputs", mode="before")
|
285
299
|
@classmethod
|
@@ -290,15 +304,22 @@ class StateMeta(UniversalBaseModel):
|
|
290
304
|
return node_outputs
|
291
305
|
|
292
306
|
raw_workflow_nodes = workflow_definition.get_nodes()
|
293
|
-
workflow_node_outputs = {}
|
307
|
+
workflow_node_outputs: Dict[Union[str, UUID], OutputReference] = {}
|
294
308
|
for node in raw_workflow_nodes:
|
295
309
|
for output in node.Outputs:
|
296
310
|
workflow_node_outputs[str(output)] = output
|
311
|
+
output_id = node.__output_ids__.get(output.name)
|
312
|
+
if output_id:
|
313
|
+
workflow_node_outputs[output_id] = output
|
297
314
|
|
298
315
|
node_output_keys = list(node_outputs.keys())
|
299
316
|
deserialized_node_outputs = {}
|
300
317
|
for node_output_key in node_output_keys:
|
301
|
-
|
318
|
+
if is_valid_uuid(node_output_key):
|
319
|
+
output_reference = workflow_node_outputs.get(UUID(node_output_key))
|
320
|
+
else:
|
321
|
+
output_reference = workflow_node_outputs.get(node_output_key)
|
322
|
+
|
302
323
|
if not output_reference:
|
303
324
|
continue
|
304
325
|
|
@@ -316,10 +337,11 @@ class StateMeta(UniversalBaseModel):
|
|
316
337
|
if not workflow_definition:
|
317
338
|
return node_execution_cache
|
318
339
|
|
319
|
-
nodes_cache: Dict[str, Type["BaseNode"]] = {}
|
340
|
+
nodes_cache: Dict[Union[str, UUID], Type["BaseNode"]] = {}
|
320
341
|
raw_workflow_nodes = workflow_definition.get_nodes()
|
321
342
|
for node in raw_workflow_nodes:
|
322
343
|
nodes_cache[str(node)] = node
|
344
|
+
nodes_cache[node.__id__] = node
|
323
345
|
|
324
346
|
return NodeExecutionCache.deserialize(node_execution_cache, nodes_cache)
|
325
347
|
|
@@ -47,6 +47,9 @@ class MockNode(BaseNode):
|
|
47
47
|
baz: str
|
48
48
|
|
49
49
|
|
50
|
+
MOCK_NODE_OUTPUT_ID = "e4dc3136-0c27-4bda-b3ab-ea355d5219d6"
|
51
|
+
|
52
|
+
|
50
53
|
def test_state_snapshot__node_attribute_edit():
|
51
54
|
# GIVEN an initial state instance
|
52
55
|
state = MockState(foo="bar")
|
@@ -144,7 +147,7 @@ def test_state_json_serialization__with_node_output_updates():
|
|
144
147
|
json_state = json.loads(json.dumps(state, cls=DefaultStateEncoder))
|
145
148
|
|
146
149
|
# THEN the state is serialized correctly
|
147
|
-
assert json_state["meta"]["node_outputs"] == {
|
150
|
+
assert json_state["meta"]["node_outputs"] == {MOCK_NODE_OUTPUT_ID: "hello"}
|
148
151
|
|
149
152
|
|
150
153
|
def test_state_deepcopy__with_external_input_updates():
|
@@ -185,7 +188,7 @@ def test_state_json_serialization__with_queue():
|
|
185
188
|
json_state = json.loads(json.dumps(state, cls=DefaultStateEncoder))
|
186
189
|
|
187
190
|
# THEN the state is serialized correctly with the queue turned into a list
|
188
|
-
assert json_state["meta"]["node_outputs"] == {
|
191
|
+
assert json_state["meta"]["node_outputs"] == {MOCK_NODE_OUTPUT_ID: ["test1", "test2"]}
|
189
192
|
|
190
193
|
|
191
194
|
def test_state_snapshot__deepcopy_fails__logs_error(mock_deepcopy, mock_logger):
|
vellum/workflows/types/stack.py
CHANGED
@@ -37,3 +37,14 @@ class Stack(Generic[_T]):
|
|
37
37
|
|
38
38
|
def dump(self) -> List[_T]:
    """Return a new list of the stack's items in the reverse of their internal storage order."""
    snapshot = list(self._items)
    snapshot.reverse()
    return snapshot
|
40
|
+
|
41
|
+
@classmethod
def from_list(cls, items: List[_T]) -> "Stack[_T]":
    """Alternate constructor: create an empty stack and ``extend`` it with ``items``."""
    new_stack = cls()
    new_stack.extend(items)
    return new_stack
|
46
|
+
|
47
|
+
def __eq__(self, other: object) -> bool:
    """
    Two stacks are equal when their underlying item storage compares equal.

    For non-Stack operands this returns ``NotImplemented``, letting Python try the other operand's
    ``__eq__`` before falling back to identity (so ``stack == non_stack`` is still False by default).
    NOTE(review): defining ``__eq__`` implicitly sets ``__hash__`` to None, making Stack unhashable
    unless the class defines ``__hash__`` elsewhere — confirm no caller hashes stacks.
    """
    if not isinstance(other, Stack):
        return NotImplemented
    return self._items == other._items
|
@@ -80,6 +80,11 @@ class _BaseWorkflowMeta(type):
|
|
80
80
|
def __new__(mcs, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any]) -> Any:
|
81
81
|
if "graph" not in dct:
|
82
82
|
dct["graph"] = set()
|
83
|
+
for base in bases:
|
84
|
+
base_graph = getattr(base, "graph", None)
|
85
|
+
if base_graph:
|
86
|
+
dct["graph"] = base_graph
|
87
|
+
break
|
83
88
|
|
84
89
|
if "Outputs" in dct:
|
85
90
|
outputs_class = dct["Outputs"]
|