vellum-ai 0.14.55__py3-none-any.whl → 0.14.57__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/workflows/nodes/bases/base.py +16 -8
- vellum/workflows/nodes/core/retry_node/node.py +6 -0
- vellum/workflows/nodes/displayable/api_node/node.py +8 -1
- vellum/workflows/nodes/displayable/api_node/tests/test_api_node.py +66 -3
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +14 -10
- vellum/workflows/nodes/experimental/tool_calling_node/node.py +3 -1
- vellum/workflows/nodes/experimental/tool_calling_node/utils.py +67 -23
- vellum/workflows/runner/runner.py +14 -10
- vellum/workflows/state/base.py +28 -10
- vellum/workflows/state/encoder.py +12 -1
- vellum/workflows/utils/functions.py +42 -1
- vellum/workflows/utils/tests/test_functions.py +156 -1
- vellum/workflows/workflows/tests/test_base_workflow.py +4 -4
- {vellum_ai-0.14.55.dist-info → vellum_ai-0.14.57.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.55.dist-info → vellum_ai-0.14.57.dist-info}/RECORD +22 -21
- vellum_ee/workflows/display/nodes/vellum/tests/test_tool_calling_node.py +118 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_prompt_node_serialization.py +265 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_serialization.py +14 -10
- {vellum_ai-0.14.55.dist-info → vellum_ai-0.14.57.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.55.dist-info → vellum_ai-0.14.57.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.55.dist-info → vellum_ai-0.14.57.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
|
|
18
18
|
headers: typing.Dict[str, str] = {
|
19
19
|
"X-Fern-Language": "Python",
|
20
20
|
"X-Fern-SDK-Name": "vellum-ai",
|
21
|
-
"X-Fern-SDK-Version": "0.14.
|
21
|
+
"X-Fern-SDK-Version": "0.14.57",
|
22
22
|
}
|
23
23
|
headers["X-API-KEY"] = self.api_key
|
24
24
|
return headers
|
@@ -8,7 +8,6 @@ from typing import Any, Dict, Generic, Iterator, Optional, Set, Tuple, Type, Typ
|
|
8
8
|
from vellum.workflows.constants import undefined
|
9
9
|
from vellum.workflows.descriptors.base import BaseDescriptor
|
10
10
|
from vellum.workflows.descriptors.utils import is_unresolved, resolve_value
|
11
|
-
from vellum.workflows.edges.edge import Edge
|
12
11
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
13
12
|
from vellum.workflows.exceptions import NodeException
|
14
13
|
from vellum.workflows.graph import Graph
|
@@ -317,8 +316,15 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
|
|
317
316
|
"""
|
318
317
|
# Check if all dependencies have invoked this node
|
319
318
|
dependencies_invoked = state.meta.node_execution_cache._dependencies_invoked.get(node_span_id, set())
|
320
|
-
|
321
|
-
|
319
|
+
node_classes_invoked = {
|
320
|
+
state.meta.node_execution_cache.__node_execution_lookup__[dep]
|
321
|
+
for dep in dependencies_invoked
|
322
|
+
if dep in state.meta.node_execution_cache.__node_execution_lookup__
|
323
|
+
}
|
324
|
+
if len(node_classes_invoked) != len(dependencies):
|
325
|
+
return False
|
326
|
+
|
327
|
+
all_deps_invoked = all(dep in node_classes_invoked for dep in dependencies)
|
322
328
|
return all_deps_invoked
|
323
329
|
|
324
330
|
raise NodeException(
|
@@ -328,7 +334,7 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
|
|
328
334
|
|
329
335
|
@classmethod
|
330
336
|
def _queue_node_execution(
|
331
|
-
cls, state: StateType, dependencies: Set["Type[BaseNode]"], invoked_by: Optional[
|
337
|
+
cls, state: StateType, dependencies: Set["Type[BaseNode]"], invoked_by: Optional[UUID] = None
|
332
338
|
) -> UUID:
|
333
339
|
"""
|
334
340
|
Queues a future execution of a node, returning the span id of the execution.
|
@@ -341,19 +347,21 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
|
|
341
347
|
if not invoked_by:
|
342
348
|
return execution_id
|
343
349
|
|
350
|
+
if invoked_by not in state.meta.node_execution_cache.__node_execution_lookup__:
|
351
|
+
return execution_id
|
352
|
+
|
344
353
|
if cls.merge_behavior not in {MergeBehavior.AWAIT_ANY, MergeBehavior.AWAIT_ALL}:
|
345
354
|
return execution_id
|
346
355
|
|
347
|
-
source_node = invoked_by.from_port.node_class
|
348
356
|
for queued_node_execution_id in state.meta.node_execution_cache._node_executions_queued[cls.node_class]:
|
349
|
-
if
|
357
|
+
if invoked_by not in state.meta.node_execution_cache._dependencies_invoked[queued_node_execution_id]:
|
350
358
|
state.meta.node_execution_cache._invoke_dependency(
|
351
|
-
queued_node_execution_id, cls.node_class,
|
359
|
+
queued_node_execution_id, cls.node_class, invoked_by, dependencies
|
352
360
|
)
|
353
361
|
return queued_node_execution_id
|
354
362
|
|
355
363
|
state.meta.node_execution_cache._node_executions_queued[cls.node_class].append(execution_id)
|
356
|
-
state.meta.node_execution_cache._invoke_dependency(execution_id, cls.node_class,
|
364
|
+
state.meta.node_execution_cache._invoke_dependency(execution_id, cls.node_class, invoked_by, dependencies)
|
357
365
|
return execution_id
|
358
366
|
|
359
367
|
class Execution(metaclass=_BaseNodeExecutionMeta):
|
@@ -70,6 +70,12 @@ class RetryNode(BaseAdornmentNode[StateType], Generic[StateType]):
|
|
70
70
|
|
71
71
|
for output_descriptor, output_value in event.outputs:
|
72
72
|
setattr(node_outputs, output_descriptor.name, output_value)
|
73
|
+
|
74
|
+
if self.__wrapped_node__:
|
75
|
+
inner_desc = getattr(self.__wrapped_node__.Outputs, output_descriptor.name, None)
|
76
|
+
if inner_desc:
|
77
|
+
self.state.meta.node_outputs[inner_desc] = output_value
|
78
|
+
|
73
79
|
elif event.name == "workflow.execution.paused":
|
74
80
|
exception = NodeException(
|
75
81
|
code=WorkflowErrorCode.INVALID_OUTPUTS,
|
@@ -47,11 +47,18 @@ class APINode(BaseAPINode):
|
|
47
47
|
self.bearer_token_value, VellumSecret
|
48
48
|
):
|
49
49
|
bearer_token = self.bearer_token_value
|
50
|
+
|
51
|
+
final_headers = {**headers, **header_overrides}
|
52
|
+
|
53
|
+
vellum_client_wrapper = self._context.vellum_client._client_wrapper
|
54
|
+
if self.url.startswith(vellum_client_wrapper._environment.default) and "X-API-Key" not in final_headers:
|
55
|
+
final_headers["X-API-Key"] = vellum_client_wrapper.api_key
|
56
|
+
|
50
57
|
return self._run(
|
51
58
|
method=self.method,
|
52
59
|
url=self.url,
|
53
60
|
data=self.data,
|
54
61
|
json=self.json,
|
55
|
-
headers=
|
62
|
+
headers=final_headers,
|
56
63
|
bearer_token=bearer_token,
|
57
64
|
)
|
@@ -19,7 +19,7 @@ def test_run_workflow__secrets(vellum_client):
|
|
19
19
|
class SimpleBaseAPINode(APINode):
|
20
20
|
method = APIRequestMethod.POST
|
21
21
|
authorization_type = AuthorizationType.BEARER_TOKEN
|
22
|
-
url = "https://
|
22
|
+
url = "https://example.vellum.ai"
|
23
23
|
json = {
|
24
24
|
"key": "value",
|
25
25
|
}
|
@@ -32,7 +32,9 @@ def test_run_workflow__secrets(vellum_client):
|
|
32
32
|
terminal = node.run()
|
33
33
|
|
34
34
|
assert vellum_client.execute_api.call_count == 1
|
35
|
+
assert vellum_client.execute_api.call_args.kwargs["url"] == "https://example.vellum.ai"
|
35
36
|
assert vellum_client.execute_api.call_args.kwargs["body"] == {"key": "value"}
|
37
|
+
assert vellum_client.execute_api.call_args.kwargs["headers"] == {"X-Test-Header": "foo"}
|
36
38
|
bearer_token = vellum_client.execute_api.call_args.kwargs["bearer_token"]
|
37
39
|
assert bearer_token == ClientVellumSecret(name="secret")
|
38
40
|
assert terminal.headers == {"X-Response-Header": "bar"}
|
@@ -45,7 +47,7 @@ def test_api_node_raises_error_when_api_call_fails(vellum_client):
|
|
45
47
|
class SimpleAPINode(APINode):
|
46
48
|
method = APIRequestMethod.GET
|
47
49
|
authorization_type = AuthorizationType.BEARER_TOKEN
|
48
|
-
url = "https://
|
50
|
+
url = "https://example.vellum.ai"
|
49
51
|
json = {
|
50
52
|
"key": "value",
|
51
53
|
}
|
@@ -65,7 +67,9 @@ def test_api_node_raises_error_when_api_call_fails(vellum_client):
|
|
65
67
|
|
66
68
|
# AND the API call should have been made
|
67
69
|
assert vellum_client.execute_api.call_count == 1
|
70
|
+
assert vellum_client.execute_api.call_args.kwargs["url"] == "https://example.vellum.ai"
|
68
71
|
assert vellum_client.execute_api.call_args.kwargs["body"] == {"key": "value"}
|
72
|
+
assert vellum_client.execute_api.call_args.kwargs["headers"] == {"X-Test-Header": "foo"}
|
69
73
|
|
70
74
|
|
71
75
|
def test_api_node_defaults_to_get_method(vellum_client):
|
@@ -80,7 +84,7 @@ def test_api_node_defaults_to_get_method(vellum_client):
|
|
80
84
|
# AND an API node without a method specified
|
81
85
|
class SimpleAPINodeWithoutMethod(APINode):
|
82
86
|
authorization_type = AuthorizationType.BEARER_TOKEN
|
83
|
-
url = "https://
|
87
|
+
url = "https://example.vellum.ai"
|
84
88
|
headers = {
|
85
89
|
"X-Test-Header": "foo",
|
86
90
|
}
|
@@ -95,3 +99,62 @@ def test_api_node_defaults_to_get_method(vellum_client):
|
|
95
99
|
assert vellum_client.execute_api.call_count == 1
|
96
100
|
method = vellum_client.execute_api.call_args.kwargs["method"]
|
97
101
|
assert method == APIRequestMethod.GET.value
|
102
|
+
|
103
|
+
|
104
|
+
def test_api_node__detects_client_environment_urls__adds_token(mock_httpx_transport, mock_requests, monkeypatch):
|
105
|
+
# GIVEN an API node with a URL pointing back to Vellum
|
106
|
+
class SimpleAPINodeToVellum(APINode):
|
107
|
+
url = "https://api.vellum.ai"
|
108
|
+
|
109
|
+
# AND a mock request sent to the Vellum API would return a 200
|
110
|
+
mock_response = mock_requests.get(
|
111
|
+
"https://api.vellum.ai",
|
112
|
+
status_code=200,
|
113
|
+
json={"data": [1, 2, 3]},
|
114
|
+
)
|
115
|
+
|
116
|
+
# AND an api key is set
|
117
|
+
monkeypatch.setenv("VELLUM_API_KEY", "vellum-api-key-1234")
|
118
|
+
|
119
|
+
# WHEN we run the node
|
120
|
+
node = SimpleAPINodeToVellum()
|
121
|
+
node.run()
|
122
|
+
|
123
|
+
# THEN the execute_api method should not have been called
|
124
|
+
mock_httpx_transport.handle_request.assert_not_called()
|
125
|
+
|
126
|
+
# AND the vellum API should have been called with the correct headers
|
127
|
+
assert mock_response.last_request
|
128
|
+
assert mock_response.last_request.headers["X-API-Key"] == "vellum-api-key-1234"
|
129
|
+
|
130
|
+
|
131
|
+
def test_api_node__detects_client_environment_urls__does_not_override_headers(
|
132
|
+
mock_httpx_transport, mock_requests, monkeypatch
|
133
|
+
):
|
134
|
+
# GIVEN an API node with a URL pointing back to Vellum
|
135
|
+
class SimpleAPINodeToVellum(APINode):
|
136
|
+
url = "https://api.vellum.ai"
|
137
|
+
headers = {
|
138
|
+
"X-API-Key": "vellum-api-key-5678",
|
139
|
+
}
|
140
|
+
|
141
|
+
# AND a mock request sent to the Vellum API would return a 200
|
142
|
+
mock_response = mock_requests.get(
|
143
|
+
"https://api.vellum.ai",
|
144
|
+
status_code=200,
|
145
|
+
json={"data": [1, 2, 3]},
|
146
|
+
)
|
147
|
+
|
148
|
+
# AND an api key is set
|
149
|
+
monkeypatch.setenv("VELLUM_API_KEY", "vellum-api-key-1234")
|
150
|
+
|
151
|
+
# WHEN we run the node
|
152
|
+
node = SimpleAPINodeToVellum()
|
153
|
+
node.run()
|
154
|
+
|
155
|
+
# THEN the execute_api method should not have been called
|
156
|
+
mock_httpx_transport.handle_request.assert_not_called()
|
157
|
+
|
158
|
+
# AND the vellum API should have been called with the correct headers
|
159
|
+
assert mock_response.last_request
|
160
|
+
assert mock_response.last_request.headers["X-API-Key"] == "vellum-api-key-5678"
|
@@ -30,8 +30,8 @@ from vellum.workflows.nodes.displayable.bases.base_prompt_node import BasePrompt
|
|
30
30
|
from vellum.workflows.nodes.displayable.bases.inline_prompt_node.constants import DEFAULT_PROMPT_PARAMETERS
|
31
31
|
from vellum.workflows.outputs import BaseOutput
|
32
32
|
from vellum.workflows.types import MergeBehavior
|
33
|
-
from vellum.workflows.types.generics import StateType
|
34
|
-
from vellum.workflows.utils.functions import compile_function_definition
|
33
|
+
from vellum.workflows.types.generics import StateType, is_workflow_class
|
34
|
+
from vellum.workflows.utils.functions import compile_function_definition, compile_workflow_function_definition
|
35
35
|
|
36
36
|
|
37
37
|
class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
|
@@ -97,14 +97,18 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
|
|
97
97
|
"execution_context": execution_context.model_dump(mode="json"),
|
98
98
|
**request_options.get("additional_body_parameters", {}),
|
99
99
|
}
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
]
|
105
|
-
|
106
|
-
|
107
|
-
|
100
|
+
|
101
|
+
normalized_functions: Optional[List[FunctionDefinition]] = None
|
102
|
+
|
103
|
+
if self.functions:
|
104
|
+
normalized_functions = []
|
105
|
+
for function in self.functions:
|
106
|
+
if isinstance(function, FunctionDefinition):
|
107
|
+
normalized_functions.append(function)
|
108
|
+
elif is_workflow_class(function):
|
109
|
+
normalized_functions.append(compile_workflow_function_definition(function))
|
110
|
+
else:
|
111
|
+
normalized_functions.append(compile_function_definition(function))
|
108
112
|
|
109
113
|
if self.settings and not self.settings.stream_enabled:
|
110
114
|
# This endpoint is returning a single event, so we need to wrap it in a generator
|
@@ -1,6 +1,8 @@
|
|
1
1
|
from collections.abc import Callable
|
2
2
|
from typing import Any, ClassVar, List, Optional
|
3
3
|
|
4
|
+
from pydash import snake_case
|
5
|
+
|
4
6
|
from vellum import ChatMessage, PromptBlock
|
5
7
|
from vellum.workflows.context import execution_context, get_parent_context
|
6
8
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
@@ -101,7 +103,7 @@ class ToolCallingNode(BaseNode):
|
|
101
103
|
)
|
102
104
|
|
103
105
|
self._function_nodes = {
|
104
|
-
function.__name__: create_function_node(
|
106
|
+
snake_case(function.__name__): create_function_node(
|
105
107
|
function=function,
|
106
108
|
tool_router_node=self.tool_router_node,
|
107
109
|
)
|
@@ -2,22 +2,30 @@ from collections.abc import Callable
|
|
2
2
|
import json
|
3
3
|
from typing import Any, Iterator, List, Optional, Type, cast
|
4
4
|
|
5
|
-
from
|
5
|
+
from pydash import snake_case
|
6
|
+
|
7
|
+
from vellum import ChatMessage, PromptBlock
|
6
8
|
from vellum.client.types.function_call_chat_message_content import FunctionCallChatMessageContent
|
7
9
|
from vellum.client.types.function_call_chat_message_content_value import FunctionCallChatMessageContentValue
|
10
|
+
from vellum.client.types.string_chat_message_content import StringChatMessageContent
|
8
11
|
from vellum.client.types.variable_prompt_block import VariablePromptBlock
|
12
|
+
from vellum.workflows.errors.types import WorkflowErrorCode
|
13
|
+
from vellum.workflows.exceptions import NodeException
|
14
|
+
from vellum.workflows.inputs.base import BaseInputs
|
9
15
|
from vellum.workflows.nodes.bases import BaseNode
|
10
16
|
from vellum.workflows.nodes.displayable.inline_prompt_node.node import InlinePromptNode
|
11
17
|
from vellum.workflows.outputs.base import BaseOutput
|
12
18
|
from vellum.workflows.ports.port import Port
|
13
19
|
from vellum.workflows.references.lazy import LazyReference
|
20
|
+
from vellum.workflows.state.encoder import DefaultStateEncoder
|
14
21
|
from vellum.workflows.types.core import EntityInputsInterface, MergeBehavior
|
22
|
+
from vellum.workflows.types.generics import is_workflow_class
|
15
23
|
|
16
24
|
|
17
25
|
class FunctionNode(BaseNode):
|
18
26
|
"""Node that executes a specific function."""
|
19
27
|
|
20
|
-
|
28
|
+
pass
|
21
29
|
|
22
30
|
|
23
31
|
class ToolRouterNode(InlinePromptNode):
|
@@ -57,26 +65,32 @@ def create_tool_router_node(
|
|
57
65
|
functions: List[Callable[..., Any]],
|
58
66
|
prompt_inputs: Optional[EntityInputsInterface],
|
59
67
|
) -> Type[ToolRouterNode]:
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
68
|
+
if functions and len(functions) > 0:
|
69
|
+
# If we have functions, create dynamic ports for each function
|
70
|
+
Ports = type("Ports", (), {})
|
71
|
+
for function in functions:
|
72
|
+
function_name = snake_case(function.__name__)
|
73
|
+
|
74
|
+
# Avoid using lambda to capture function_name
|
75
|
+
# lambda will capture the function_name by reference,
|
76
|
+
# and if the function_name is changed, the port_condition will also change.
|
77
|
+
def create_port_condition(fn_name):
|
78
|
+
return LazyReference(
|
79
|
+
lambda: (
|
80
|
+
node.Outputs.results[0]["type"].equals("FUNCTION_CALL")
|
81
|
+
& node.Outputs.results[0]["value"]["name"].equals(fn_name)
|
82
|
+
)
|
72
83
|
)
|
73
|
-
)
|
74
84
|
|
75
|
-
|
76
|
-
|
77
|
-
|
85
|
+
port_condition = create_port_condition(function_name)
|
86
|
+
port = Port.on_if(port_condition)
|
87
|
+
setattr(Ports, function_name, port)
|
78
88
|
|
79
|
-
|
89
|
+
# Add the else port for when no function conditions match
|
90
|
+
setattr(Ports, "default", Port.on_else())
|
91
|
+
else:
|
92
|
+
# If no functions exist, create a simple Ports class with just a default port
|
93
|
+
Ports = type("Ports", (), {"default": Port(default=True)})
|
80
94
|
|
81
95
|
# Add a chat history block to blocks
|
82
96
|
blocks.append(
|
@@ -121,10 +135,41 @@ def create_function_node(function: Callable[..., Any], tool_router_node: Type[To
|
|
121
135
|
outputs = json.loads(outputs)
|
122
136
|
arguments = outputs["arguments"]
|
123
137
|
|
124
|
-
# Call the
|
125
|
-
|
138
|
+
# Call the function based on its type
|
139
|
+
if is_workflow_class(function):
|
140
|
+
# Dynamically define an Inputs subclass of BaseInputs
|
141
|
+
Inputs = type(
|
142
|
+
"Inputs",
|
143
|
+
(BaseInputs,),
|
144
|
+
{"__annotations__": {k: type(v) for k, v in arguments.items()}},
|
145
|
+
)
|
146
|
+
|
147
|
+
# Create an instance with arguments
|
148
|
+
inputs_instance = Inputs(**arguments)
|
126
149
|
|
127
|
-
|
150
|
+
workflow = function()
|
151
|
+
terminal_event = workflow.run(
|
152
|
+
inputs=inputs_instance,
|
153
|
+
)
|
154
|
+
if terminal_event.name == "workflow.execution.paused":
|
155
|
+
raise NodeException(
|
156
|
+
code=WorkflowErrorCode.INVALID_OUTPUTS,
|
157
|
+
message="Subworkflow unexpectedly paused",
|
158
|
+
)
|
159
|
+
elif terminal_event.name == "workflow.execution.fulfilled":
|
160
|
+
result = terminal_event.outputs
|
161
|
+
elif terminal_event.name == "workflow.execution.rejected":
|
162
|
+
raise Exception(f"Workflow execution rejected: {terminal_event.error}")
|
163
|
+
else:
|
164
|
+
# If it's a regular callable, call it directly
|
165
|
+
result = function(**arguments)
|
166
|
+
|
167
|
+
self.state.chat_history.append(
|
168
|
+
ChatMessage(
|
169
|
+
role="FUNCTION",
|
170
|
+
content=StringChatMessageContent(value=json.dumps(result, cls=DefaultStateEncoder)),
|
171
|
+
)
|
172
|
+
)
|
128
173
|
|
129
174
|
return self.Outputs()
|
130
175
|
|
@@ -132,7 +177,6 @@ def create_function_node(function: Callable[..., Any], tool_router_node: Type[To
|
|
132
177
|
f"FunctionNode_{function.__name__}",
|
133
178
|
(FunctionNode,),
|
134
179
|
{
|
135
|
-
"function": function,
|
136
180
|
"run": execute_function,
|
137
181
|
"__module__": __name__,
|
138
182
|
},
|
@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Option
|
|
10
10
|
from vellum.workflows.constants import undefined
|
11
11
|
from vellum.workflows.context import ExecutionContext, execution_context, get_execution_context
|
12
12
|
from vellum.workflows.descriptors.base import BaseDescriptor
|
13
|
-
from vellum.workflows.edges.edge import Edge
|
14
13
|
from vellum.workflows.errors import WorkflowError, WorkflowErrorCode
|
15
14
|
from vellum.workflows.events import (
|
16
15
|
NodeExecutionFulfilledEvent,
|
@@ -143,7 +142,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
143
142
|
self._workflow_event_inner_queue: Queue[WorkflowEvent] = Queue()
|
144
143
|
|
145
144
|
self._max_concurrency = max_concurrency
|
146
|
-
self._concurrency_queue: Queue[Tuple[StateType, Type[BaseNode], Optional[
|
145
|
+
self._concurrency_queue: Queue[Tuple[StateType, Type[BaseNode], Optional[UUID]]] = Queue()
|
147
146
|
|
148
147
|
# This queue is responsible for sending events from WorkflowRunner to the background thread
|
149
148
|
# for user defined emitters
|
@@ -389,7 +388,12 @@ class WorkflowRunner(Generic[StateType]):
|
|
389
388
|
):
|
390
389
|
self._run_work_item(node, span_id)
|
391
390
|
|
392
|
-
def _handle_invoked_ports(
|
391
|
+
def _handle_invoked_ports(
|
392
|
+
self,
|
393
|
+
state: StateType,
|
394
|
+
ports: Optional[Iterable[Port]],
|
395
|
+
invoked_by: Optional[UUID],
|
396
|
+
) -> None:
|
393
397
|
if not ports:
|
394
398
|
return
|
395
399
|
|
@@ -402,9 +406,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
402
406
|
next_state = state
|
403
407
|
|
404
408
|
if self._max_concurrency:
|
405
|
-
self._concurrency_queue.put((next_state, edge.to_node,
|
409
|
+
self._concurrency_queue.put((next_state, edge.to_node, invoked_by))
|
406
410
|
else:
|
407
|
-
self._run_node_if_ready(next_state, edge.to_node,
|
411
|
+
self._run_node_if_ready(next_state, edge.to_node, invoked_by)
|
408
412
|
|
409
413
|
if self._max_concurrency:
|
410
414
|
num_nodes_to_run = self._max_concurrency - len(self._active_nodes_by_execution_id)
|
@@ -412,14 +416,14 @@ class WorkflowRunner(Generic[StateType]):
|
|
412
416
|
if self._concurrency_queue.empty():
|
413
417
|
break
|
414
418
|
|
415
|
-
next_state, node_class,
|
416
|
-
self._run_node_if_ready(next_state, node_class,
|
419
|
+
next_state, node_class, invoked_by = self._concurrency_queue.get()
|
420
|
+
self._run_node_if_ready(next_state, node_class, invoked_by)
|
417
421
|
|
418
422
|
def _run_node_if_ready(
|
419
423
|
self,
|
420
424
|
state: StateType,
|
421
425
|
node_class: Type[BaseNode],
|
422
|
-
invoked_by: Optional[
|
426
|
+
invoked_by: Optional[UUID] = None,
|
423
427
|
) -> None:
|
424
428
|
with state.__lock__:
|
425
429
|
for descriptor in node_class.ExternalInputs:
|
@@ -482,7 +486,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
482
486
|
)
|
483
487
|
)
|
484
488
|
|
485
|
-
self._handle_invoked_ports(node.state, event.invoked_ports)
|
489
|
+
self._handle_invoked_ports(node.state, event.invoked_ports, event.span_id)
|
486
490
|
|
487
491
|
return None
|
488
492
|
|
@@ -508,7 +512,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
508
512
|
)
|
509
513
|
)
|
510
514
|
|
511
|
-
self._handle_invoked_ports(node.state, event.invoked_ports)
|
515
|
+
self._handle_invoked_ports(node.state, event.invoked_ports, event.span_id)
|
512
516
|
|
513
517
|
return None
|
514
518
|
|
vellum/workflows/state/base.py
CHANGED
@@ -93,7 +93,8 @@ def _make_snapshottable(value: Any, snapshot_callback: Callable[[], None]) -> An
|
|
93
93
|
NodeExecutionsFulfilled = Dict[Type["BaseNode"], Stack[UUID]]
|
94
94
|
NodeExecutionsInitiated = Dict[Type["BaseNode"], Set[UUID]]
|
95
95
|
NodeExecutionsQueued = Dict[Type["BaseNode"], List[UUID]]
|
96
|
-
|
96
|
+
NodeExecutionLookup = Dict[UUID, Type["BaseNode"]]
|
97
|
+
DependenciesInvoked = Dict[UUID, Set[UUID]]
|
97
98
|
|
98
99
|
|
99
100
|
class NodeExecutionCache:
|
@@ -102,11 +103,15 @@ class NodeExecutionCache:
|
|
102
103
|
_node_executions_queued: NodeExecutionsQueued
|
103
104
|
_dependencies_invoked: DependenciesInvoked
|
104
105
|
|
106
|
+
# Derived fields - no need to serialize
|
107
|
+
__node_execution_lookup__: NodeExecutionLookup
|
108
|
+
|
105
109
|
def __init__(self) -> None:
|
106
110
|
self._dependencies_invoked = defaultdict(set)
|
107
111
|
self._node_executions_fulfilled = defaultdict(Stack[UUID])
|
108
112
|
self._node_executions_initiated = defaultdict(set)
|
109
113
|
self._node_executions_queued = defaultdict(list)
|
114
|
+
self.__node_execution_lookup__ = {}
|
110
115
|
|
111
116
|
@classmethod
|
112
117
|
def deserialize(cls, raw_data: dict, nodes: Dict[Union[str, UUID], Type["BaseNode"]]):
|
@@ -124,10 +129,8 @@ class NodeExecutionCache:
|
|
124
129
|
dependencies_invoked = raw_data.get("dependencies_invoked")
|
125
130
|
if isinstance(dependencies_invoked, dict):
|
126
131
|
for execution_id, dependencies in dependencies_invoked.items():
|
127
|
-
|
128
|
-
cache._dependencies_invoked[UUID(execution_id)] =
|
129
|
-
dep_class for dep_class in dependency_classes if dep_class is not None
|
130
|
-
}
|
132
|
+
dependency_execution_ids = {UUID(dep) for dep in dependencies if is_valid_uuid(dep)}
|
133
|
+
cache._dependencies_invoked[UUID(execution_id)] = dependency_execution_ids
|
131
134
|
|
132
135
|
node_executions_fulfilled = raw_data.get("node_executions_fulfilled")
|
133
136
|
if isinstance(node_executions_fulfilled, dict):
|
@@ -151,6 +154,10 @@ class NodeExecutionCache:
|
|
151
154
|
{UUID(execution_id) for execution_id in execution_ids}
|
152
155
|
)
|
153
156
|
|
157
|
+
for node_class, execution_ids in cache._node_executions_initiated.items():
|
158
|
+
for execution_id in execution_ids:
|
159
|
+
cache.__node_execution_lookup__[execution_id] = node_class
|
160
|
+
|
154
161
|
node_executions_queued = raw_data.get("node_executions_queued")
|
155
162
|
if isinstance(node_executions_queued, dict):
|
156
163
|
for node, execution_ids in node_executions_queued.items():
|
@@ -166,18 +173,29 @@ class NodeExecutionCache:
|
|
166
173
|
self,
|
167
174
|
execution_id: UUID,
|
168
175
|
node: Type["BaseNode"],
|
169
|
-
|
176
|
+
invoked_by: UUID,
|
170
177
|
dependencies: Set["Type[BaseNode]"],
|
171
178
|
) -> None:
|
172
|
-
self._dependencies_invoked[execution_id].add(
|
173
|
-
|
174
|
-
self.
|
179
|
+
self._dependencies_invoked[execution_id].add(invoked_by)
|
180
|
+
invoked_node_classes = {
|
181
|
+
self.__node_execution_lookup__[dep]
|
182
|
+
for dep in self._dependencies_invoked[execution_id]
|
183
|
+
if dep in self.__node_execution_lookup__
|
184
|
+
}
|
185
|
+
if len(invoked_node_classes) != len(dependencies):
|
186
|
+
return
|
187
|
+
|
188
|
+
if any(dep not in invoked_node_classes for dep in dependencies):
|
189
|
+
return
|
190
|
+
|
191
|
+
self._node_executions_queued[node].remove(execution_id)
|
175
192
|
|
176
193
|
def is_node_execution_initiated(self, node: Type["BaseNode"], execution_id: UUID) -> bool:
|
177
194
|
return execution_id in self._node_executions_initiated[node]
|
178
195
|
|
179
196
|
def initiate_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
|
180
197
|
self._node_executions_initiated[node].add(execution_id)
|
198
|
+
self.__node_execution_lookup__[execution_id] = node
|
181
199
|
|
182
200
|
def fulfill_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
|
183
201
|
self._node_executions_fulfilled[node].push(execution_id)
|
@@ -188,7 +206,7 @@ class NodeExecutionCache:
|
|
188
206
|
def dump(self) -> Dict[str, Any]:
|
189
207
|
return {
|
190
208
|
"dependencies_invoked": {
|
191
|
-
str(execution_id): [str(dep
|
209
|
+
str(execution_id): [str(dep) for dep in dependencies]
|
192
210
|
for execution_id, dependencies in self._dependencies_invoked.items()
|
193
211
|
},
|
194
212
|
"node_executions_initiated": {
|
@@ -1,6 +1,7 @@
|
|
1
1
|
from dataclasses import asdict, is_dataclass
|
2
2
|
from datetime import datetime
|
3
3
|
import enum
|
4
|
+
import inspect
|
4
5
|
from json import JSONEncoder
|
5
6
|
from queue import Queue
|
6
7
|
from uuid import UUID
|
@@ -59,7 +60,17 @@ class DefaultStateEncoder(JSONEncoder):
|
|
59
60
|
return str(obj)
|
60
61
|
|
61
62
|
if callable(obj):
|
62
|
-
|
63
|
+
function_definition = compile_function_definition(obj)
|
64
|
+
try:
|
65
|
+
source_code = inspect.getsource(obj)
|
66
|
+
except Exception:
|
67
|
+
source_code = f"<source code not available for {obj.__name__}>"
|
68
|
+
|
69
|
+
return {
|
70
|
+
"type": "CODE_EXECUTION",
|
71
|
+
"definition": function_definition,
|
72
|
+
"src": source_code,
|
73
|
+
}
|
63
74
|
|
64
75
|
if obj.__class__ in self.encoders:
|
65
76
|
return self.encoders[obj.__class__](obj)
|
@@ -1,12 +1,16 @@
|
|
1
1
|
import dataclasses
|
2
2
|
import inspect
|
3
|
-
from typing import Any, Callable, Optional, Union, get_args, get_origin
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Type, Union, get_args, get_origin
|
4
4
|
|
5
5
|
from pydantic import BaseModel
|
6
6
|
from pydantic_core import PydanticUndefined
|
7
|
+
from pydash import snake_case
|
7
8
|
|
8
9
|
from vellum.client.types.function_definition import FunctionDefinition
|
9
10
|
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from vellum.workflows.workflows.base import BaseWorkflow
|
13
|
+
|
10
14
|
type_map = {
|
11
15
|
str: "string",
|
12
16
|
int: "integer",
|
@@ -108,5 +112,42 @@ def compile_function_definition(function: Callable) -> FunctionDefinition:
|
|
108
112
|
|
109
113
|
return FunctionDefinition(
|
110
114
|
name=function.__name__,
|
115
|
+
description=function.__doc__,
|
116
|
+
parameters=parameters,
|
117
|
+
)
|
118
|
+
|
119
|
+
|
120
|
+
def compile_workflow_function_definition(workflow_class: Type["BaseWorkflow"]) -> FunctionDefinition:
|
121
|
+
"""
|
122
|
+
Converts a base workflow class into our Vellum-native FunctionDefinition type.
|
123
|
+
"""
|
124
|
+
# Get the inputs class for the workflow
|
125
|
+
inputs_class = workflow_class.get_inputs_class()
|
126
|
+
vars_inputs_class = vars(inputs_class)
|
127
|
+
|
128
|
+
properties = {}
|
129
|
+
required = []
|
130
|
+
defs: dict[str, Any] = {}
|
131
|
+
|
132
|
+
for name, field_type in inputs_class.__annotations__.items():
|
133
|
+
if name.startswith("__"):
|
134
|
+
continue
|
135
|
+
|
136
|
+
properties[name] = _compile_annotation(field_type, defs)
|
137
|
+
|
138
|
+
# Check if the field has a default value
|
139
|
+
if name not in vars_inputs_class:
|
140
|
+
required.append(name)
|
141
|
+
else:
|
142
|
+
# Field has a default value
|
143
|
+
properties[name]["default"] = vars_inputs_class[name]
|
144
|
+
|
145
|
+
parameters = {"type": "object", "properties": properties, "required": required}
|
146
|
+
if defs:
|
147
|
+
parameters["$defs"] = defs
|
148
|
+
|
149
|
+
return FunctionDefinition(
|
150
|
+
name=snake_case(workflow_class.__name__),
|
151
|
+
description=workflow_class.__doc__,
|
111
152
|
parameters=parameters,
|
112
153
|
)
|