vellum-ai 0.14.56__py3-none-any.whl → 0.14.58__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/workflows/nodes/bases/base.py +19 -8
- vellum/workflows/nodes/core/retry_node/node.py +6 -0
- vellum/workflows/nodes/displayable/api_node/node.py +8 -1
- vellum/workflows/nodes/displayable/api_node/tests/test_api_node.py +66 -3
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +14 -10
- vellum/workflows/nodes/displayable/guardrail_node/node.py +13 -2
- vellum/workflows/nodes/displayable/guardrail_node/test_node.py +29 -0
- vellum/workflows/nodes/experimental/tool_calling_node/node.py +3 -1
- vellum/workflows/nodes/experimental/tool_calling_node/utils.py +46 -8
- vellum/workflows/runner/runner.py +14 -10
- vellum/workflows/state/base.py +28 -10
- vellum/workflows/state/encoder.py +5 -1
- vellum/workflows/utils/functions.py +42 -1
- vellum/workflows/utils/tests/test_functions.py +156 -1
- vellum/workflows/workflows/tests/test_base_workflow.py +4 -4
- {vellum_ai-0.14.56.dist-info → vellum_ai-0.14.58.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.56.dist-info → vellum_ai-0.14.58.dist-info}/RECORD +24 -23
- vellum_ee/workflows/display/nodes/vellum/tests/test_tool_calling_node.py +118 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_prompt_node_serialization.py +265 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_serialization.py +2 -1
- {vellum_ai-0.14.56.dist-info → vellum_ai-0.14.58.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.56.dist-info → vellum_ai-0.14.58.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.56.dist-info → vellum_ai-0.14.58.dist-info}/entry_points.txt +0 -0
vellum/client/core/client_wrapper.py
CHANGED

@@ -18,7 +18,7 @@ class BaseClientWrapper:
         headers: typing.Dict[str, str] = {
             "X-Fern-Language": "Python",
             "X-Fern-SDK-Name": "vellum-ai",
-            "X-Fern-SDK-Version": "0.14.56",
+            "X-Fern-SDK-Version": "0.14.58",
         }
         headers["X-API-KEY"] = self.api_key
         return headers
vellum/workflows/nodes/bases/base.py
CHANGED

@@ -8,7 +8,6 @@ from typing import Any, Dict, Generic, Iterator, Optional, Set, Tuple, Type, Typ…
 from vellum.workflows.constants import undefined
 from vellum.workflows.descriptors.base import BaseDescriptor
 from vellum.workflows.descriptors.utils import is_unresolved, resolve_value
-from vellum.workflows.edges.edge import Edge
 from vellum.workflows.errors.types import WorkflowErrorCode
 from vellum.workflows.exceptions import NodeException
 from vellum.workflows.graph import Graph
@@ -317,8 +316,15 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
         """
         # Check if all dependencies have invoked this node
         dependencies_invoked = state.meta.node_execution_cache._dependencies_invoked.get(node_span_id, set())
-        …
-        …
+        node_classes_invoked = {
+            state.meta.node_execution_cache.__node_execution_lookup__[dep]
+            for dep in dependencies_invoked
+            if dep in state.meta.node_execution_cache.__node_execution_lookup__
+        }
+        if len(node_classes_invoked) != len(dependencies):
+            return False
+
+        all_deps_invoked = all(dep in node_classes_invoked for dep in dependencies)
         return all_deps_invoked
 
         raise NodeException(
@@ -328,7 +334,7 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
 
     @classmethod
     def _queue_node_execution(
-        cls, state: StateType, dependencies: Set["Type[BaseNode]"], invoked_by: Optional[Edge] = None
+        cls, state: StateType, dependencies: Set["Type[BaseNode]"], invoked_by: Optional[UUID] = None
     ) -> UUID:
         """
         Queues a future execution of a node, returning the span id of the execution.
@@ -341,19 +347,24 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
         if not invoked_by:
             return execution_id
 
+        if invoked_by not in state.meta.node_execution_cache.__node_execution_lookup__:
+            return execution_id
+
         if cls.merge_behavior not in {MergeBehavior.AWAIT_ANY, MergeBehavior.AWAIT_ALL}:
+            # Keep track of the dependencies that have invoked this node
+            # This would be needed while climbing the history in the loop
+            state.meta.node_execution_cache._dependencies_invoked[execution_id].add(invoked_by)
             return execution_id
 
-        source_node = invoked_by.from_port.node_class
         for queued_node_execution_id in state.meta.node_execution_cache._node_executions_queued[cls.node_class]:
-            if …
+            if invoked_by not in state.meta.node_execution_cache._dependencies_invoked[queued_node_execution_id]:
                 state.meta.node_execution_cache._invoke_dependency(
-                    queued_node_execution_id, cls.node_class, …
+                    queued_node_execution_id, cls.node_class, invoked_by, dependencies
                 )
                 return queued_node_execution_id
 
         state.meta.node_execution_cache._node_executions_queued[cls.node_class].append(execution_id)
-        state.meta.node_execution_cache._invoke_dependency(execution_id, cls.node_class, …
+        state.meta.node_execution_cache._invoke_dependency(execution_id, cls.node_class, invoked_by, dependencies)
         return execution_id
 
 class Execution(metaclass=_BaseNodeExecutionMeta):
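The readiness check above now keys fan-in merges off which upstream *executions* have invoked a node, rather than graph edges. Below is a minimal sketch of the fan-in shape this serves, assuming the SDK's set-based graph syntax and the `Trigger.merge_behavior` attribute; all class names are illustrative:

```python
from vellum.workflows import BaseWorkflow
from vellum.workflows.nodes.bases import BaseNode
from vellum.workflows.types import MergeBehavior


class TopNode(BaseNode):
    pass


class BottomNode(BaseNode):
    pass


class MergeNode(BaseNode):
    # Only runs once an execution of every upstream dependency has invoked it
    class Trigger(BaseNode.Trigger):
        merge_behavior = MergeBehavior.AWAIT_ALL


class FanInWorkflow(BaseWorkflow):
    graph = {TopNode, BottomNode} >> MergeNode
```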
vellum/workflows/nodes/core/retry_node/node.py
CHANGED

@@ -70,6 +70,12 @@ class RetryNode(BaseAdornmentNode[StateType], Generic[StateType]):
 
                 for output_descriptor, output_value in event.outputs:
                     setattr(node_outputs, output_descriptor.name, output_value)
+
+                    if self.__wrapped_node__:
+                        inner_desc = getattr(self.__wrapped_node__.Outputs, output_descriptor.name, None)
+                        if inner_desc:
+                            self.state.meta.node_outputs[inner_desc] = output_value
+
             elif event.name == "workflow.execution.paused":
                 exception = NodeException(
                     code=WorkflowErrorCode.INVALID_OUTPUTS,
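With the added lines, outputs are also recorded under the wrapped node's own Outputs descriptors, so downstream references to the inner node keep resolving after a retry. A hedged sketch of the adornment usage this affects, assuming the `RetryNode.wrap` decorator and a `max_attempts` attribute; names are illustrative:

```python
from vellum.workflows.nodes.bases import BaseNode
from vellum.workflows.nodes.core.retry_node.node import RetryNode


@RetryNode.wrap(max_attempts=3)
class FlakyNode(BaseNode):
    class Outputs(BaseNode.Outputs):
        result: str

    def run(self) -> "FlakyNode.Outputs":
        return self.Outputs(result="ok")


# Downstream nodes can keep referencing FlakyNode.Outputs.result; the retry
# adornment now records the value against the inner node's descriptor too.
```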
vellum/workflows/nodes/displayable/api_node/node.py
CHANGED

@@ -47,11 +47,18 @@ class APINode(BaseAPINode):
             self.bearer_token_value, VellumSecret
         ):
             bearer_token = self.bearer_token_value
+
+        final_headers = {**headers, **header_overrides}
+
+        vellum_client_wrapper = self._context.vellum_client._client_wrapper
+        if self.url.startswith(vellum_client_wrapper._environment.default) and "X-API-Key" not in final_headers:
+            final_headers["X-API-Key"] = vellum_client_wrapper.api_key
+
         return self._run(
             method=self.method,
             url=self.url,
             data=self.data,
             json=self.json,
-            headers=…
+            headers=final_headers,
             bearer_token=bearer_token,
         )
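The block above means an APINode whose `url` targets the configured Vellum environment gets the client's API key attached automatically, unless the caller already set one. A minimal sketch, with a hypothetical endpoint path:

```python
from vellum.workflows.nodes.displayable.api_node.node import APINode


class CallVellumApi(APINode):
    # Points at the default Vellum environment, so X-API-Key is injected
    # from the node's Vellum client at run time.
    url = "https://api.vellum.ai/v1/example-endpoint"  # hypothetical path


class CallVellumApiWithOwnKey(APINode):
    url = "https://api.vellum.ai/v1/example-endpoint"  # hypothetical path
    # An explicit X-API-Key header is respected and never overridden.
    headers = {"X-API-Key": "my-own-key"}
```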
vellum/workflows/nodes/displayable/api_node/tests/test_api_node.py
CHANGED

@@ -19,7 +19,7 @@ def test_run_workflow__secrets(vellum_client):
     class SimpleBaseAPINode(APINode):
         method = APIRequestMethod.POST
         authorization_type = AuthorizationType.BEARER_TOKEN
-        url = "https://…
+        url = "https://example.vellum.ai"
         json = {
             "key": "value",
         }

@@ -32,7 +32,9 @@ def test_run_workflow__secrets(vellum_client):
     terminal = node.run()
 
     assert vellum_client.execute_api.call_count == 1
+    assert vellum_client.execute_api.call_args.kwargs["url"] == "https://example.vellum.ai"
     assert vellum_client.execute_api.call_args.kwargs["body"] == {"key": "value"}
+    assert vellum_client.execute_api.call_args.kwargs["headers"] == {"X-Test-Header": "foo"}
     bearer_token = vellum_client.execute_api.call_args.kwargs["bearer_token"]
     assert bearer_token == ClientVellumSecret(name="secret")
     assert terminal.headers == {"X-Response-Header": "bar"}

@@ -45,7 +47,7 @@ def test_api_node_raises_error_when_api_call_fails(vellum_client):
     class SimpleAPINode(APINode):
         method = APIRequestMethod.GET
         authorization_type = AuthorizationType.BEARER_TOKEN
-        url = "https://…
+        url = "https://example.vellum.ai"
         json = {
             "key": "value",
         }

@@ -65,7 +67,9 @@ def test_api_node_raises_error_when_api_call_fails(vellum_client):
 
     # AND the API call should have been made
     assert vellum_client.execute_api.call_count == 1
+    assert vellum_client.execute_api.call_args.kwargs["url"] == "https://example.vellum.ai"
     assert vellum_client.execute_api.call_args.kwargs["body"] == {"key": "value"}
+    assert vellum_client.execute_api.call_args.kwargs["headers"] == {"X-Test-Header": "foo"}
 
 
 def test_api_node_defaults_to_get_method(vellum_client):

@@ -80,7 +84,7 @@ def test_api_node_defaults_to_get_method(vellum_client):
     # AND an API node without a method specified
     class SimpleAPINodeWithoutMethod(APINode):
         authorization_type = AuthorizationType.BEARER_TOKEN
-        url = "https://…
+        url = "https://example.vellum.ai"
         headers = {
             "X-Test-Header": "foo",
         }

@@ -95,3 +99,62 @@ def test_api_node_defaults_to_get_method(vellum_client):
     assert vellum_client.execute_api.call_count == 1
     method = vellum_client.execute_api.call_args.kwargs["method"]
     assert method == APIRequestMethod.GET.value
+
+
+def test_api_node__detects_client_environment_urls__adds_token(mock_httpx_transport, mock_requests, monkeypatch):
+    # GIVEN an API node with a URL pointing back to Vellum
+    class SimpleAPINodeToVellum(APINode):
+        url = "https://api.vellum.ai"
+
+    # AND a mock request sent to the Vellum API would return a 200
+    mock_response = mock_requests.get(
+        "https://api.vellum.ai",
+        status_code=200,
+        json={"data": [1, 2, 3]},
+    )
+
+    # AND an api key is set
+    monkeypatch.setenv("VELLUM_API_KEY", "vellum-api-key-1234")
+
+    # WHEN we run the node
+    node = SimpleAPINodeToVellum()
+    node.run()
+
+    # THEN the execute_api method should not have been called
+    mock_httpx_transport.handle_request.assert_not_called()
+
+    # AND the vellum API should have been called with the correct headers
+    assert mock_response.last_request
+    assert mock_response.last_request.headers["X-API-Key"] == "vellum-api-key-1234"
+
+
+def test_api_node__detects_client_environment_urls__does_not_override_headers(
+    mock_httpx_transport, mock_requests, monkeypatch
+):
+    # GIVEN an API node with a URL pointing back to Vellum
+    class SimpleAPINodeToVellum(APINode):
+        url = "https://api.vellum.ai"
+        headers = {
+            "X-API-Key": "vellum-api-key-5678",
+        }
+
+    # AND a mock request sent to the Vellum API would return a 200
+    mock_response = mock_requests.get(
+        "https://api.vellum.ai",
+        status_code=200,
+        json={"data": [1, 2, 3]},
+    )
+
+    # AND an api key is set
+    monkeypatch.setenv("VELLUM_API_KEY", "vellum-api-key-1234")
+
+    # WHEN we run the node
+    node = SimpleAPINodeToVellum()
+    node.run()
+
+    # THEN the execute_api method should not have been called
+    mock_httpx_transport.handle_request.assert_not_called()
+
+    # AND the vellum API should have been called with the correct headers
+    assert mock_response.last_request
+    assert mock_response.last_request.headers["X-API-Key"] == "vellum-api-key-5678"
vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py
CHANGED

@@ -30,8 +30,8 @@ from vellum.workflows.nodes.displayable.bases.base_prompt_node import BasePrompt…
 from vellum.workflows.nodes.displayable.bases.inline_prompt_node.constants import DEFAULT_PROMPT_PARAMETERS
 from vellum.workflows.outputs import BaseOutput
 from vellum.workflows.types import MergeBehavior
-from vellum.workflows.types.generics import StateType
-from vellum.workflows.utils.functions import compile_function_definition
+from vellum.workflows.types.generics import StateType, is_workflow_class
+from vellum.workflows.utils.functions import compile_function_definition, compile_workflow_function_definition
 
 
 class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):

@@ -97,14 +97,18 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
             "execution_context": execution_context.model_dump(mode="json"),
             **request_options.get("additional_body_parameters", {}),
         }
-        …
-        …
-        …
-        …
-            ]
-        …
-        …
-        …
+
+        normalized_functions: Optional[List[FunctionDefinition]] = None
+
+        if self.functions:
+            normalized_functions = []
+            for function in self.functions:
+                if isinstance(function, FunctionDefinition):
+                    normalized_functions.append(function)
+                elif is_workflow_class(function):
+                    normalized_functions.append(compile_workflow_function_definition(function))
+                else:
+                    normalized_functions.append(compile_function_definition(function))
 
         if self.settings and not self.settings.stream_enabled:
             # This endpoint is returning a single event, so we need to wrap it in a generator
vellum/workflows/nodes/displayable/guardrail_node/node.py
CHANGED

@@ -37,6 +37,7 @@ class GuardrailNode(BaseNode[StateType], Generic[StateType]):
         score: float
         normalized_score: Optional[float]
         log: Optional[str]
+        reason: Optional[str]
 
     def run(self) -> Outputs:
         try:

@@ -54,10 +55,10 @@ class GuardrailNode(BaseNode[StateType], Generic[StateType]):
         )
 
         metric_outputs = {output.name: output.value for output in metric_execution.outputs}
-
         SCORE_KEY = "score"
         NORMALIZED_SCORE_KEY = "normalized_score"
         LOG_KEY = "log"
+        REASON_KEY = "reason"
 
         score = metric_outputs.get(SCORE_KEY)
         if not isinstance(score, float):

@@ -87,7 +88,17 @@ class GuardrailNode(BaseNode[StateType], Generic[StateType]):
         else:
             log = None
 
-        …
+        if REASON_KEY in metric_outputs:
+            reason = metric_outputs.pop(REASON_KEY) or ""
+            if not isinstance(reason, str):
+                raise NodeException(
+                    message="Metric execution reason output must be of type 'str'",
+                    code=WorkflowErrorCode.INVALID_OUTPUTS,
+                )
+        else:
+            reason = None
+
+        return self.Outputs(score=score, normalized_score=normalized_score, log=log, reason=reason, **metric_outputs)
 
     def _compile_metric_inputs(self) -> List[MetricDefinitionInput]:
         # TODO: We may want to consolidate with prompt deployment input compilation
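A short usage sketch of the new `reason` output, assuming a metric that emits a string `reason` alongside its score; the metric name and inputs are hypothetical:

```python
from vellum.workflows.nodes.displayable.guardrail_node.node import GuardrailNode


class ToxicityGuard(GuardrailNode):
    metric_definition = "toxicity_metric"  # hypothetical metric
    metric_inputs = {"input": "some text to score"}


outputs = ToxicityGuard().run()
print(outputs.score)   # float, always present
print(outputs.reason)  # Optional[str]; None when the metric emits no "reason"
```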
vellum/workflows/nodes/displayable/guardrail_node/test_node.py
CHANGED

@@ -102,6 +102,35 @@ def test_run_guardrail_node__normalized_score_null(vellum_client):
     assert exc_info.value.code == WorkflowErrorCode.INVALID_OUTPUTS
 
 
+def test_run_guardrail_node__reason(vellum_client):
+    # GIVEN a Guardrail Node
+    class MyGuard(GuardrailNode):
+        metric_definition = "example_metric_definition"
+        metric_inputs = {}
+
+    # AND we know that the guardrail node will return a reason
+    mock_metric_execution = MetricDefinitionExecution(
+        outputs=[
+            TestSuiteRunMetricNumberOutput(
+                name="score",
+                value=0.6,
+            ),
+            TestSuiteRunMetricStringOutput(
+                name="reason",
+                value="foo",
+            ),
+        ],
+    )
+    vellum_client.metric_definitions.execute_metric_definition.return_value = mock_metric_execution
+
+    # WHEN we run the Guardrail Node
+    outputs = MyGuard().run()
+
+    # THEN the workflow should have completed successfully
+    assert outputs.score == 0.6
+    assert outputs.reason == "foo"
+
+
 def test_run_guardrail_node__api_error(vellum_client):
     # GIVEN a Guardrail Node
     class MyGuard(GuardrailNode):
vellum/workflows/nodes/experimental/tool_calling_node/node.py
CHANGED

@@ -1,6 +1,8 @@
 from collections.abc import Callable
 from typing import Any, ClassVar, List, Optional
 
+from pydash import snake_case
+
 from vellum import ChatMessage, PromptBlock
 from vellum.workflows.context import execution_context, get_parent_context
 from vellum.workflows.errors.types import WorkflowErrorCode

@@ -101,7 +103,7 @@ class ToolCallingNode(BaseNode):
         )
 
         self._function_nodes = {
-            function.__name__: create_function_node(
+            snake_case(function.__name__): create_function_node(
                 function=function,
                 tool_router_node=self.tool_router_node,
             )
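Both call sites now normalize tool names through `pydash.snake_case`, so a router port and its function node agree on one key regardless of how the function was named:

```python
from pydash import snake_case

# These all resolve to the same function-node key:
print(snake_case("GetCurrentWeather"))    # "get_current_weather"
print(snake_case("getCurrentWeather"))    # "get_current_weather"
print(snake_case("get_current_weather"))  # "get_current_weather"
```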
vellum/workflows/nodes/experimental/tool_calling_node/utils.py
CHANGED

@@ -2,22 +2,30 @@ from collections.abc import Callable
 import json
 from typing import Any, Iterator, List, Optional, Type, cast
 
-from …
+from pydash import snake_case
+
+from vellum import ChatMessage, PromptBlock
 from vellum.client.types.function_call_chat_message_content import FunctionCallChatMessageContent
 from vellum.client.types.function_call_chat_message_content_value import FunctionCallChatMessageContentValue
+from vellum.client.types.string_chat_message_content import StringChatMessageContent
 from vellum.client.types.variable_prompt_block import VariablePromptBlock
+from vellum.workflows.errors.types import WorkflowErrorCode
+from vellum.workflows.exceptions import NodeException
+from vellum.workflows.inputs.base import BaseInputs
 from vellum.workflows.nodes.bases import BaseNode
 from vellum.workflows.nodes.displayable.inline_prompt_node.node import InlinePromptNode
 from vellum.workflows.outputs.base import BaseOutput
 from vellum.workflows.ports.port import Port
 from vellum.workflows.references.lazy import LazyReference
+from vellum.workflows.state.encoder import DefaultStateEncoder
 from vellum.workflows.types.core import EntityInputsInterface, MergeBehavior
+from vellum.workflows.types.generics import is_workflow_class
 
 
 class FunctionNode(BaseNode):
     """Node that executes a specific function."""
 
-    …
+    pass
 
 
 class ToolRouterNode(InlinePromptNode):

@@ -61,7 +69,7 @@ def create_tool_router_node(
     # If we have functions, create dynamic ports for each function
     Ports = type("Ports", (), {})
     for function in functions:
-        function_name = function.__name__
+        function_name = snake_case(function.__name__)
 
         # Avoid using lambda to capture function_name
         # lambda will capture the function_name by reference,

@@ -127,10 +135,41 @@ def create_function_node(function: Callable[..., Any], tool_router_node: Type[To…
         outputs = json.loads(outputs)
         arguments = outputs["arguments"]
 
-        # Call the …
-        …
-        …
-        …
+        # Call the function based on its type
+        if is_workflow_class(function):
+            # Dynamically define an Inputs subclass of BaseInputs
+            Inputs = type(
+                "Inputs",
+                (BaseInputs,),
+                {"__annotations__": {k: type(v) for k, v in arguments.items()}},
+            )
+
+            # Create an instance with arguments
+            inputs_instance = Inputs(**arguments)
+
+            workflow = function()
+            terminal_event = workflow.run(
+                inputs=inputs_instance,
+            )
+            if terminal_event.name == "workflow.execution.paused":
+                raise NodeException(
+                    code=WorkflowErrorCode.INVALID_OUTPUTS,
+                    message="Subworkflow unexpectedly paused",
+                )
+            elif terminal_event.name == "workflow.execution.fulfilled":
+                result = terminal_event.outputs
+            elif terminal_event.name == "workflow.execution.rejected":
+                raise Exception(f"Workflow execution rejected: {terminal_event.error}")
+        else:
+            # If it's a regular callable, call it directly
+            result = function(**arguments)
+
+        self.state.chat_history.append(
+            ChatMessage(
+                role="FUNCTION",
+                content=StringChatMessageContent(value=json.dumps(result, cls=DefaultStateEncoder)),
+            )
+        )
 
         return self.Outputs()
 
@@ -138,7 +177,6 @@ def create_function_node(function: Callable[..., Any], tool_router_node: Type[To…
         f"FunctionNode_{function.__name__}",
         (FunctionNode,),
         {
-            "function": function,
             "run": execute_function,
             "__module__": __name__,
         },
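The core of the change above is the dynamic `BaseInputs` subclass built from the tool-call arguments, plus serializing the result with `DefaultStateEncoder` before appending it as a FUNCTION chat message. A standalone sketch of those two pieces, with illustrative argument data:

```python
import json

from vellum.workflows.inputs.base import BaseInputs
from vellum.workflows.state.encoder import DefaultStateEncoder

arguments = {"city": "Paris", "days": 3}  # illustrative tool-call arguments

# Mirrors create_function_node: annotations are derived from the runtime
# types of the arguments, then an instance is built from the same dict.
Inputs = type(
    "Inputs",
    (BaseInputs,),
    {"__annotations__": {k: type(v) for k, v in arguments.items()}},
)
inputs_instance = Inputs(**arguments)

# A fulfilled subworkflow's outputs (here stubbed as a dict) are serialized
# with DefaultStateEncoder so non-JSON-native values survive the round trip.
result = {"forecast": "sunny"}
print(json.dumps(result, cls=DefaultStateEncoder))
```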
vellum/workflows/runner/runner.py
CHANGED

@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Option…
 from vellum.workflows.constants import undefined
 from vellum.workflows.context import ExecutionContext, execution_context, get_execution_context
 from vellum.workflows.descriptors.base import BaseDescriptor
-from vellum.workflows.edges.edge import Edge
 from vellum.workflows.errors import WorkflowError, WorkflowErrorCode
 from vellum.workflows.events import (
     NodeExecutionFulfilledEvent,

@@ -143,7 +142,7 @@ class WorkflowRunner(Generic[StateType]):
         self._workflow_event_inner_queue: Queue[WorkflowEvent] = Queue()
 
         self._max_concurrency = max_concurrency
-        self._concurrency_queue: Queue[Tuple[StateType, Type[BaseNode], Optional[…
+        self._concurrency_queue: Queue[Tuple[StateType, Type[BaseNode], Optional[UUID]]] = Queue()
 
         # This queue is responsible for sending events from WorkflowRunner to the background thread
         # for user defined emitters

@@ -389,7 +388,12 @@ class WorkflowRunner(Generic[StateType]):
         ):
             self._run_work_item(node, span_id)
 
-    def _handle_invoked_ports(…
+    def _handle_invoked_ports(
+        self,
+        state: StateType,
+        ports: Optional[Iterable[Port]],
+        invoked_by: Optional[UUID],
+    ) -> None:
         if not ports:
             return
 
@@ -402,9 +406,9 @@ class WorkflowRunner(Generic[StateType]):
             next_state = state
 
             if self._max_concurrency:
-                self._concurrency_queue.put((next_state, edge.to_node, …
+                self._concurrency_queue.put((next_state, edge.to_node, invoked_by))
             else:
-                self._run_node_if_ready(next_state, edge.to_node, …
+                self._run_node_if_ready(next_state, edge.to_node, invoked_by)
 
         if self._max_concurrency:
             num_nodes_to_run = self._max_concurrency - len(self._active_nodes_by_execution_id)

@@ -412,14 +416,14 @@ class WorkflowRunner(Generic[StateType]):
             if self._concurrency_queue.empty():
                 break
 
-            next_state, node_class, …
-            self._run_node_if_ready(next_state, node_class, …
+            next_state, node_class, invoked_by = self._concurrency_queue.get()
+            self._run_node_if_ready(next_state, node_class, invoked_by)
 
     def _run_node_if_ready(
         self,
         state: StateType,
         node_class: Type[BaseNode],
-        invoked_by: Optional[…
+        invoked_by: Optional[UUID] = None,
     ) -> None:
         with state.__lock__:
             for descriptor in node_class.ExternalInputs:

@@ -482,7 +486,7 @@ class WorkflowRunner(Generic[StateType]):
                 )
             )
 
-            self._handle_invoked_ports(node.state, event.invoked_ports)
+            self._handle_invoked_ports(node.state, event.invoked_ports, event.span_id)
 
             return None
 
@@ -508,7 +512,7 @@ class WorkflowRunner(Generic[StateType]):
                 )
             )
 
-            self._handle_invoked_ports(node.state, event.invoked_ports)
+            self._handle_invoked_ports(node.state, event.invoked_ports, event.span_id)
 
             return None
 
vellum/workflows/state/base.py
CHANGED
@@ -93,7 +93,8 @@ def _make_snapshottable(value: Any, snapshot_callback: Callable[[], None]) -> An…
 NodeExecutionsFulfilled = Dict[Type["BaseNode"], Stack[UUID]]
 NodeExecutionsInitiated = Dict[Type["BaseNode"], Set[UUID]]
 NodeExecutionsQueued = Dict[Type["BaseNode"], List[UUID]]
-…
+NodeExecutionLookup = Dict[UUID, Type["BaseNode"]]
+DependenciesInvoked = Dict[UUID, Set[UUID]]
 
 
 class NodeExecutionCache:

@@ -102,11 +103,15 @@ class NodeExecutionCache:
     _node_executions_queued: NodeExecutionsQueued
     _dependencies_invoked: DependenciesInvoked
 
+    # Derived fields - no need to serialize
+    __node_execution_lookup__: NodeExecutionLookup
+
     def __init__(self) -> None:
         self._dependencies_invoked = defaultdict(set)
         self._node_executions_fulfilled = defaultdict(Stack[UUID])
         self._node_executions_initiated = defaultdict(set)
         self._node_executions_queued = defaultdict(list)
+        self.__node_execution_lookup__ = {}
 
     @classmethod
     def deserialize(cls, raw_data: dict, nodes: Dict[Union[str, UUID], Type["BaseNode"]]):

@@ -124,10 +129,8 @@ class NodeExecutionCache:
         dependencies_invoked = raw_data.get("dependencies_invoked")
         if isinstance(dependencies_invoked, dict):
             for execution_id, dependencies in dependencies_invoked.items():
-                …
-                cache._dependencies_invoked[UUID(execution_id)] = {
-                    dep_class for dep_class in dependency_classes if dep_class is not None
-                }
+                dependency_execution_ids = {UUID(dep) for dep in dependencies if is_valid_uuid(dep)}
+                cache._dependencies_invoked[UUID(execution_id)] = dependency_execution_ids
 
         node_executions_fulfilled = raw_data.get("node_executions_fulfilled")
         if isinstance(node_executions_fulfilled, dict):

@@ -151,6 +154,10 @@ class NodeExecutionCache:
                 {UUID(execution_id) for execution_id in execution_ids}
             )
 
+        for node_class, execution_ids in cache._node_executions_initiated.items():
+            for execution_id in execution_ids:
+                cache.__node_execution_lookup__[execution_id] = node_class
+
         node_executions_queued = raw_data.get("node_executions_queued")
         if isinstance(node_executions_queued, dict):
             for node, execution_ids in node_executions_queued.items():

@@ -166,18 +173,29 @@ class NodeExecutionCache:
         self,
         execution_id: UUID,
         node: Type["BaseNode"],
-        …
+        invoked_by: UUID,
         dependencies: Set["Type[BaseNode]"],
     ) -> None:
-        self._dependencies_invoked[execution_id].add(…
-        …
-        self.…
+        self._dependencies_invoked[execution_id].add(invoked_by)
+        invoked_node_classes = {
+            self.__node_execution_lookup__[dep]
+            for dep in self._dependencies_invoked[execution_id]
+            if dep in self.__node_execution_lookup__
+        }
+        if len(invoked_node_classes) != len(dependencies):
+            return
+
+        if any(dep not in invoked_node_classes for dep in dependencies):
+            return
+
+        self._node_executions_queued[node].remove(execution_id)
 
     def is_node_execution_initiated(self, node: Type["BaseNode"], execution_id: UUID) -> bool:
         return execution_id in self._node_executions_initiated[node]
 
     def initiate_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
         self._node_executions_initiated[node].add(execution_id)
+        self.__node_execution_lookup__[execution_id] = node
 
     def fulfill_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
         self._node_executions_fulfilled[node].push(execution_id)

@@ -188,7 +206,7 @@ class NodeExecutionCache:
     def dump(self) -> Dict[str, Any]:
         return {
             "dependencies_invoked": {
-                str(execution_id): [str(dep…
+                str(execution_id): [str(dep) for dep in dependencies]
                 for execution_id, dependencies in self._dependencies_invoked.items()
             },
             "node_executions_initiated": {
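After this change, `dependencies_invoked` serializes execution-id-to-execution-id relationships (UUID strings on both sides), and the derived `__node_execution_lookup__` index is rebuilt from `node_executions_initiated` during `deserialize` rather than being persisted. A sketch of the raw shape `deserialize` now accepts, with made-up UUIDs:

```python
from uuid import uuid4

parent_span = str(uuid4())
child_span = str(uuid4())

raw_data = {
    # Execution ids map to the execution ids that invoked them:
    "dependencies_invoked": {child_span: [parent_span]},
    "node_executions_initiated": {},
    "node_executions_fulfilled": {},
    "node_executions_queued": {},
}
# NodeExecutionCache.deserialize(raw_data, nodes={...}) would rebuild the
# derived __node_execution_lookup__ from node_executions_initiated.
```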
vellum/workflows/state/encoder.py
CHANGED

@@ -66,7 +66,11 @@ class DefaultStateEncoder(JSONEncoder):
         except Exception:
             source_code = f"<source code not available for {obj.__name__}>"
 
-        return {…
+        return {
+            "type": "CODE_EXECUTION",
+            "definition": function_definition,
+            "src": source_code,
+        }
 
         if obj.__class__ in self.encoders:
             return self.encoders[obj.__class__](obj)