vellum-ai 0.13.21__py3-none-any.whl → 0.13.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/__init__.py +12 -0
- vellum/client/__init__.py +170 -5
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/types/__init__.py +12 -0
- vellum/client/types/execute_api_request_bearer_token.py +6 -0
- vellum/client/types/execute_api_request_body.py +5 -0
- vellum/client/types/execute_api_request_headers_value.py +6 -0
- vellum/client/types/execute_api_response.py +24 -0
- vellum/client/types/method_enum.py +5 -0
- vellum/client/types/vellum_secret.py +19 -0
- vellum/plugins/pydantic.py +13 -1
- vellum/types/execute_api_request_bearer_token.py +3 -0
- vellum/types/execute_api_request_body.py +3 -0
- vellum/types/execute_api_request_headers_value.py +3 -0
- vellum/types/execute_api_response.py +3 -0
- vellum/types/method_enum.py +3 -0
- vellum/types/vellum_secret.py +3 -0
- vellum/workflows/nodes/core/inline_subworkflow_node/node.py +1 -0
- vellum/workflows/nodes/core/map_node/node.py +1 -0
- vellum/workflows/nodes/core/retry_node/node.py +1 -0
- vellum/workflows/nodes/core/try_node/node.py +11 -7
- vellum/workflows/nodes/displayable/api_node/node.py +12 -3
- vellum/workflows/nodes/displayable/api_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/api_node/tests/test_api_node.py +34 -0
- vellum/workflows/nodes/displayable/bases/api_node/node.py +25 -4
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +11 -3
- vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +27 -12
- vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +33 -0
- vellum/workflows/nodes/displayable/code_execution_node/utils.py +14 -12
- vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +49 -0
- vellum/workflows/nodes/displayable/prompt_deployment_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +96 -0
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +29 -5
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/test_node.py +131 -0
- vellum/workflows/nodes/mocks.py +17 -0
- vellum/workflows/runner/runner.py +14 -34
- vellum/workflows/state/context.py +29 -1
- vellum/workflows/workflows/base.py +9 -6
- {vellum_ai-0.13.21.dist-info → vellum_ai-0.13.23.dist-info}/METADATA +1 -1
- {vellum_ai-0.13.21.dist-info → vellum_ai-0.13.23.dist-info}/RECORD +50 -31
- vellum_cli/push.py +5 -6
- vellum_cli/tests/test_push.py +81 -1
- vellum_ee/workflows/display/types.py +1 -31
- vellum_ee/workflows/display/vellum.py +1 -1
- vellum_ee/workflows/display/workflows/base_workflow_display.py +46 -2
- vellum_ee/workflows/tests/test_server.py +9 -0
- vellum/workflows/types/cycle_map.py +0 -34
- {vellum_ai-0.13.21.dist-info → vellum_ai-0.13.23.dist-info}/LICENSE +0 -0
- {vellum_ai-0.13.21.dist-info → vellum_ai-0.13.23.dist-info}/WHEEL +0 -0
- {vellum_ai-0.13.21.dist-info → vellum_ai-0.13.23.dist-info}/entry_points.txt +0 -0
--- a/vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py
+++ b/vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py
@@ -1,6 +1,6 @@
 import json
 from uuid import uuid4
-from typing import Callable, ClassVar, Generic, Iterator, List, Optional, Tuple, Union
+from typing import Callable, ClassVar, Generic, Iterator, List, Optional, Tuple, Union
 
 from vellum import (
     AdHocExecutePromptEvent,
@@ -16,6 +16,7 @@ from vellum import (
     VellumVariable,
 )
 from vellum.client import RequestOptions
+from vellum.client.types.chat_message_request import ChatMessageRequest
 from vellum.workflows.constants import OMIT
 from vellum.workflows.context import get_parent_context
 from vellum.workflows.errors import WorkflowErrorCode
@@ -108,7 +109,14 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
                         value=input_value,
                     )
                 )
-            elif isinstance(input_value, list) and all(
+            elif isinstance(input_value, list) and all(
+                isinstance(message, (ChatMessage, ChatMessageRequest)) for message in input_value
+            ):
+                chat_history = [
+                    message if isinstance(message, ChatMessage) else ChatMessage.model_validate(message.model_dump())
+                    for message in input_value
+                    if isinstance(message, (ChatMessage, ChatMessageRequest))
+                ]
                 input_variables.append(
                     VellumVariable(
                         # TODO: Determine whether or not we actually need an id here and if we do,
@@ -122,7 +130,7 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
                 input_values.append(
                     PromptRequestChatHistoryInput(
                         key=input_name,
-                        value=
+                        value=chat_history,
                     )
                 )
             else:
--- a/vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py
+++ b/vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py
@@ -1,5 +1,6 @@
+import json
 from uuid import UUID
-from typing import Any, ClassVar, Dict, Generic, Iterator, List, Optional, Sequence, Union
+from typing import Any, ClassVar, Dict, Generic, Iterator, List, Optional, Sequence, Union
 
 from vellum import (
     ChatHistoryInputRequest,
@@ -12,9 +13,11 @@ from vellum import (
     StringInputRequest,
 )
 from vellum.client import RequestOptions
+from vellum.client.types.chat_message_request import ChatMessageRequest
 from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT
 from vellum.workflows.context import get_parent_context
 from vellum.workflows.errors import WorkflowErrorCode
+from vellum.workflows.events.types import default_serializer
 from vellum.workflows.exceptions import NodeException
 from vellum.workflows.nodes.displayable.bases.base_prompt_node import BasePromptNode
 from vellum.workflows.types import MergeBehavior
@@ -89,26 +92,38 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
                         value=input_value,
                     )
                 )
-            elif isinstance(input_value, list) and all(
+            elif isinstance(input_value, list) and all(
+                isinstance(message, (ChatMessage, ChatMessageRequest)) for message in input_value
+            ):
+                chat_history = [
+                    (
+                        message
+                        if isinstance(message, ChatMessageRequest)
+                        else ChatMessageRequest.model_validate(message.model_dump())
+                    )
+                    for message in input_value
+                    if isinstance(message, (ChatMessage, ChatMessageRequest))
+                ]
                 compiled_inputs.append(
                     ChatHistoryInputRequest(
                         name=input_name,
-                        value=
+                        value=chat_history,
                     )
                 )
-
-
-
+            else:
+                try:
+                    input_value = default_serializer(input_value)
+                except json.JSONDecodeError as e:
+                    raise NodeException(
+                        message=f"Failed to serialize input '{input_name}' of type '{input_value.__class__}': {e}",
+                        code=WorkflowErrorCode.INVALID_INPUTS,
+                    )
+
                 compiled_inputs.append(
                     JsonInputRequest(
                         name=input_name,
-                        value=
+                        value=input_value,
                     )
                 )
-            else:
-                raise NodeException(
-                    message=f"Unrecognized input type for input '{input_name}': {input_value.__class__}",
-                    code=WorkflowErrorCode.INVALID_INPUTS,
-                )
 
         return compiled_inputs
--- a/vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py
+++ b/vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py
@@ -1,5 +1,6 @@
 import pytest
 import os
+from typing import Any
 
 from vellum import CodeExecutorResponse, NumberVellumValue, StringInput
 from vellum.client.types.code_execution_package import CodeExecutionPackage
@@ -413,3 +414,35 @@ name
     Field required [type=missing, input_value={'n': 'hello', 'a': {}}, input_type=dict]\
 """
     )
+
+
+def test_run_workflow__run_inline__valid_dict_to_pydantic_any_type():
+    """Confirm that CodeExecutionNodes can convert a dict to a Pydantic model during inline execution."""
+
+    # GIVEN a node that subclasses CodeExecutionNode that returns a dict matching Any
+    class ExampleCodeExecutionNode(CodeExecutionNode[BaseState, Any]):
+        code = """\
+def main(word: str) -> dict:
+    return {
+        "name": "word",
+        "arguments": {},
+    }
+"""
+        runtime = "PYTHON_3_11_6"
+
+        code_inputs = {
+            "word": "hello",
+        }
+
+    # WHEN we run the node
+    node = ExampleCodeExecutionNode()
+    outputs = node.run()
+
+    # THEN the node should have produced the outputs we expect
+    assert outputs == {
+        "result": {
+            "name": "word",
+            "arguments": {},
+        },
+        "log": "",
+    }
--- a/vellum/workflows/nodes/displayable/code_execution_node/utils.py
+++ b/vellum/workflows/nodes/displayable/code_execution_node/utils.py
@@ -91,19 +91,21 @@ __arg__out = main({", ".join(run_args)})
     logs = log_buffer.getvalue()
     result = exec_globals["__arg__out"]
 
-    if
-
-
-
+    if output_type != Any:
+        if issubclass(output_type, BaseModel) and not isinstance(result, output_type):
+            try:
+                result = output_type.model_validate(result)
+            except ValidationError as e:
+                raise NodeException(
+                    code=WorkflowErrorCode.INVALID_OUTPUTS,
+                    message=re.sub(r"\s+For further information visit [^\s]+", "", str(e)),
+                ) from e
+
+        if not isinstance(result, output_type):
             raise NodeException(
                 code=WorkflowErrorCode.INVALID_OUTPUTS,
-                message=
-
-
-                if not isinstance(result, output_type):
-                    raise NodeException(
-                        code=WorkflowErrorCode.INVALID_OUTPUTS,
-                        message=f"Expected an output of type '{output_type.__name__}', but received '{result.__class__.__name__}'",
-                    )
+                message=f"Expected an output of type '{output_type.__name__}',"
+                f" but received '{result.__class__.__name__}'",
+            )
 
     return logs, result
--- a/vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py
+++ b/vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py
@@ -5,6 +5,8 @@ from typing import Any, Iterator, List
 
 from vellum.client.core.api_error import ApiError
 from vellum.client.core.pydantic_utilities import UniversalBaseModel
+from vellum.client.types.chat_message import ChatMessage
+from vellum.client.types.chat_message_request import ChatMessageRequest
 from vellum.client.types.execute_prompt_event import ExecutePromptEvent
 from vellum.client.types.fulfilled_execute_prompt_event import FulfilledExecutePromptEvent
 from vellum.client.types.function_call import FunctionCall
@@ -12,6 +14,7 @@ from vellum.client.types.function_call_vellum_value import FunctionCallVellumValue
 from vellum.client.types.function_definition import FunctionDefinition
 from vellum.client.types.initiated_execute_prompt_event import InitiatedExecutePromptEvent
 from vellum.client.types.prompt_output import PromptOutput
+from vellum.client.types.prompt_request_chat_history_input import PromptRequestChatHistoryInput
 from vellum.client.types.prompt_request_json_input import PromptRequestJsonInput
 from vellum.client.types.string_vellum_value import StringVellumValue
 from vellum.workflows.errors.types import WorkflowErrorCode
@@ -181,3 +184,49 @@ def test_inline_prompt_node__api_error__invalid_inputs_node_exception(
     # THEN the node raises the correct NodeException
     assert e.value.code == expected_code
     assert e.value.message == expected_message
+
+
+def test_inline_prompt_node__chat_history_inputs(vellum_adhoc_prompt_client):
+    # GIVEN a prompt node with a chat history input
+    class MyNode(InlinePromptNode):
+        ml_model = "gpt-4o"
+        blocks = []
+        prompt_inputs = {
+            "chat_history": [ChatMessageRequest(role="USER", text="Hello, how are you?")],
+        }
+
+    # AND a known response from invoking an inline prompt
+    expected_outputs: List[PromptOutput] = [
+        StringVellumValue(value="Great!"),
+    ]
+
+    def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
+        execution_id = str(uuid4())
+        events: List[ExecutePromptEvent] = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            FulfilledExecutePromptEvent(
+                execution_id=execution_id,
+                outputs=expected_outputs,
+            ),
+        ]
+        yield from events
+
+    vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.side_effect = generate_prompt_events
+
+    # WHEN the node is run
+    events = list(MyNode().run())
+
+    # THEN the prompt is executed with the correct inputs
+    assert events[-1].value == "Great!"
+
+    # AND the prompt is executed with the correct inputs
+    mock_api = vellum_adhoc_prompt_client.adhoc_execute_prompt_stream
+    assert mock_api.call_count == 1
+    assert mock_api.call_args.kwargs["input_values"] == [
+        PromptRequestChatHistoryInput(
+            key="chat_history",
+            type="CHAT_HISTORY",
+            value=[ChatMessage(role="USER", text="Hello, how are you?")],
+        ),
+    ]
+    assert mock_api.call_args.kwargs["input_variables"][0].type == "CHAT_HISTORY"
vellum/workflows/nodes/displayable/prompt_deployment_node/tests/__init__.py: new, empty file (no content to diff).
--- /dev/null
+++ b/vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py
@@ -0,0 +1,96 @@
+import pytest
+from uuid import uuid4
+from typing import Any, Iterator, List
+
+from vellum.client.types.chat_history_input_request import ChatHistoryInputRequest
+from vellum.client.types.chat_message import ChatMessage
+from vellum.client.types.chat_message_request import ChatMessageRequest
+from vellum.client.types.execute_prompt_event import ExecutePromptEvent
+from vellum.client.types.fulfilled_execute_prompt_event import FulfilledExecutePromptEvent
+from vellum.client.types.initiated_execute_prompt_event import InitiatedExecutePromptEvent
+from vellum.client.types.json_input_request import JsonInputRequest
+from vellum.client.types.string_vellum_value import StringVellumValue
+from vellum.workflows.nodes.displayable.prompt_deployment_node.node import PromptDeploymentNode
+
+
+@pytest.mark.parametrize("ChatMessageClass", [ChatMessageRequest, ChatMessage])
+def test_run_node__chat_history_input(vellum_client, ChatMessageClass):
+    """Confirm that we can successfully invoke a Prompt Deployment Node that uses Chat History Inputs"""
+
+    # GIVEN a Prompt Deployment Node
+    class ExamplePromptDeploymentNode(PromptDeploymentNode):
+        deployment = "example_prompt_deployment"
+        prompt_inputs = {
+            "chat_history": [ChatMessageClass(role="USER", text="Hello, how are you?")],
+        }
+
+    # AND we know what the Prompt Deployment will respond with
+    def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
+        execution_id = str(uuid4())
+        events: List[ExecutePromptEvent] = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            FulfilledExecutePromptEvent(
+                execution_id=execution_id,
+                outputs=[
+                    StringVellumValue(value="Great!"),
+                ],
+            ),
+        ]
+        yield from events
+
+    vellum_client.execute_prompt_stream.side_effect = generate_prompt_events
+
+    # WHEN we run the node
+    node = ExamplePromptDeploymentNode()
+    events = list(node.run())
+
+    # THEN the node should have completed successfully
+    assert events[-1].value == "Great!"
+
+    # AND we should have invoked the Prompt Deployment with the expected inputs
+    call_kwargs = vellum_client.execute_prompt_stream.call_args.kwargs
+    assert call_kwargs["inputs"] == [
+        ChatHistoryInputRequest(
+            name="chat_history", value=[ChatMessageRequest(role="USER", text="Hello, how are you?")]
+        ),
+    ]
+
+
+def test_run_node__any_array_input(vellum_client):
+    """Confirm that we can successfully invoke a Prompt Deployment Node that uses any array input"""
+
+    # GIVEN a Prompt Deployment Node
+    class ExamplePromptDeploymentNode(PromptDeploymentNode):
+        deployment = "example_prompt_deployment"
+        prompt_inputs = {
+            "fruits": ["apple", "banana", "cherry"],
+        }
+
+    # AND we know what the Prompt Deployment will respond with
+    def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
+        execution_id = str(uuid4())
+        events: List[ExecutePromptEvent] = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            FulfilledExecutePromptEvent(
+                execution_id=execution_id,
+                outputs=[
+                    StringVellumValue(value="Great!"),
+                ],
+            ),
+        ]
+        yield from events
+
+    vellum_client.execute_prompt_stream.side_effect = generate_prompt_events
+
+    # WHEN we run the node
+    node = ExamplePromptDeploymentNode()
+    events = list(node.run())
+
+    # THEN the node should have completed successfully
+    assert events[-1].value == "Great!"
+
+    # AND we should have invoked the Prompt Deployment with the expected inputs
+    call_kwargs = vellum_client.execute_prompt_stream.call_args.kwargs
+    assert call_kwargs["inputs"] == [
+        JsonInputRequest(name="fruits", value=["apple", "banana", "cherry"]),
+    ]
--- a/vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py
+++ b/vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py
@@ -1,3 +1,4 @@
+import json
 from uuid import UUID
 from typing import Any, ClassVar, Dict, Generic, Iterator, List, Optional, Set, Union, cast
 
@@ -11,11 +12,13 @@ from vellum import (
     WorkflowRequestNumberInputRequest,
     WorkflowRequestStringInputRequest,
 )
+from vellum.client.types.chat_message_request import ChatMessageRequest
 from vellum.core import RequestOptions
 from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT
 from vellum.workflows.context import get_parent_context
 from vellum.workflows.errors import WorkflowErrorCode
 from vellum.workflows.errors.types import workflow_event_error_to_workflow_error
+from vellum.workflows.events.types import default_serializer
 from vellum.workflows.exceptions import NodeException
 from vellum.workflows.nodes.bases.base import BaseNode
 from vellum.workflows.outputs.base import BaseOutput
@@ -66,11 +69,22 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
                         value=input_value,
                     )
                 )
-            elif isinstance(input_value, list) and all(
+            elif isinstance(input_value, list) and all(
+                isinstance(message, (ChatMessage, ChatMessageRequest)) for message in input_value
+            ):
+                chat_history = [
+                    (
+                        message
+                        if isinstance(message, ChatMessageRequest)
+                        else ChatMessageRequest.model_validate(message.model_dump())
+                    )
+                    for message in input_value
+                    if isinstance(message, (ChatMessage, ChatMessageRequest))
+                ]
                 compiled_inputs.append(
                     WorkflowRequestChatHistoryInputRequest(
                         name=input_name,
-                        value=
+                        value=chat_history,
                     )
                 )
             elif isinstance(input_value, dict):
@@ -88,9 +102,19 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
                     )
                 )
             else:
-
-
-
+                try:
+                    input_value = default_serializer(input_value)
+                except json.JSONDecodeError as e:
+                    raise NodeException(
+                        message=f"Failed to serialize input '{input_name}' of type '{input_value.__class__}': {e}",
+                        code=WorkflowErrorCode.INVALID_INPUTS,
+                    )
+
+                compiled_inputs.append(
+                    WorkflowRequestJsonInputRequest(
+                        name=input_name,
+                        value=input_value,
+                    )
                 )
 
         return compiled_inputs
vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/__init__.py: new, empty file (no content to diff).
--- /dev/null
+++ b/vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/test_node.py
@@ -0,0 +1,131 @@
+import pytest
+from datetime import datetime
+from uuid import uuid4
+from typing import Any, Iterator, List
+
+from vellum.client.types.chat_message import ChatMessage
+from vellum.client.types.chat_message_request import ChatMessageRequest
+from vellum.client.types.workflow_execution_workflow_result_event import WorkflowExecutionWorkflowResultEvent
+from vellum.client.types.workflow_output_string import WorkflowOutputString
+from vellum.client.types.workflow_request_chat_history_input_request import WorkflowRequestChatHistoryInputRequest
+from vellum.client.types.workflow_request_json_input_request import WorkflowRequestJsonInputRequest
+from vellum.client.types.workflow_result_event import WorkflowResultEvent
+from vellum.client.types.workflow_stream_event import WorkflowStreamEvent
+from vellum.workflows.nodes.displayable.subworkflow_deployment_node.node import SubworkflowDeploymentNode
+
+
+@pytest.mark.parametrize("ChatMessageClass", [ChatMessageRequest, ChatMessage])
+def test_run_workflow__chat_history_input(vellum_client, ChatMessageClass):
+    """Confirm that we can successfully invoke a Subworkflow Deployment Node that uses Chat History Inputs"""
+
+    # GIVEN a Subworkflow Deployment Node
+    class ExampleSubworkflowDeploymentNode(SubworkflowDeploymentNode):
+        deployment = "example_subworkflow_deployment"
+        subworkflow_inputs = {
+            "chat_history": [ChatMessageClass(role="USER", text="Hello, how are you?")],
+        }
+
+    # AND we know what the Subworkflow Deployment will respond with
+    def generate_subworkflow_events(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
+        execution_id = str(uuid4())
+        expected_events: List[WorkflowStreamEvent] = [
+            WorkflowExecutionWorkflowResultEvent(
+                execution_id=execution_id,
+                data=WorkflowResultEvent(
+                    id=str(uuid4()),
+                    state="INITIATED",
+                    ts=datetime.now(),
+                ),
+            ),
+            WorkflowExecutionWorkflowResultEvent(
+                execution_id=execution_id,
+                data=WorkflowResultEvent(
+                    id=str(uuid4()),
+                    state="FULFILLED",
+                    ts=datetime.now(),
+                    outputs=[
+                        WorkflowOutputString(
+                            id=str(uuid4()),
+                            name="greeting",
+                            value="Great!",
+                        )
+                    ],
+                ),
+            ),
+        ]
+        yield from expected_events
+
+    vellum_client.execute_workflow_stream.side_effect = generate_subworkflow_events
+
+    # WHEN we run the node
+    node = ExampleSubworkflowDeploymentNode()
+    events = list(node.run())
+
+    # THEN the node should have completed successfully
+    assert events[-1].name == "greeting"
+    assert events[-1].value == "Great!"
+
+    # AND we should have invoked the Subworkflow Deployment with the expected inputs
+    call_kwargs = vellum_client.execute_workflow_stream.call_args.kwargs
+    assert call_kwargs["inputs"] == [
+        WorkflowRequestChatHistoryInputRequest(
+            name="chat_history", value=[ChatMessageRequest(role="USER", text="Hello, how are you?")]
+        ),
+    ]
+
+
+def test_run_workflow__any_array(vellum_client):
+    """Confirm that we can successfully invoke a Subworkflow Deployment Node that uses any array input"""
+
+    # GIVEN a Subworkflow Deployment Node
+    class ExampleSubworkflowDeploymentNode(SubworkflowDeploymentNode):
+        deployment = "example_subworkflow_deployment"
+        subworkflow_inputs = {
+            "fruits": ["apple", "banana", "cherry"],
+        }
+
+    # AND we know what the Subworkflow Deployment will respond with
+    def generate_subworkflow_events(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
+        execution_id = str(uuid4())
+        expected_events: List[WorkflowStreamEvent] = [
+            WorkflowExecutionWorkflowResultEvent(
+                execution_id=execution_id,
+                data=WorkflowResultEvent(
+                    id=str(uuid4()),
+                    state="INITIATED",
+                    ts=datetime.now(),
+                ),
+            ),
+            WorkflowExecutionWorkflowResultEvent(
+                execution_id=execution_id,
+                data=WorkflowResultEvent(
+                    id=str(uuid4()),
+                    state="FULFILLED",
+                    ts=datetime.now(),
+                    outputs=[
+                        WorkflowOutputString(
+                            id=str(uuid4()),
+                            name="greeting",
+                            value="Great!",
+                        )
+                    ],
+                ),
+            ),
+        ]
+        yield from expected_events
+
+    vellum_client.execute_workflow_stream.side_effect = generate_subworkflow_events
+
+    # WHEN we run the node
+    node = ExampleSubworkflowDeploymentNode()
+    events = list(node.run())
+
+    # THEN the node should have completed successfully
+    assert events[-1].name == "greeting"
+    assert events[-1].value == "Great!"
+
+    # AND we should have invoked the Subworkflow Deployment with the expected inputs
+    call_kwargs = vellum_client.execute_workflow_stream.call_args.kwargs
+    assert call_kwargs["inputs"] == [
+        WorkflowRequestJsonInputRequest(name="fruits", value=["apple", "banana", "cherry"]),
+    ]
--- /dev/null
+++ b/vellum/workflows/nodes/mocks.py
@@ -0,0 +1,17 @@
+from typing import Sequence, Union
+
+from pydantic import ConfigDict
+
+from vellum.client.core.pydantic_utilities import UniversalBaseModel
+from vellum.workflows.descriptors.base import BaseDescriptor
+from vellum.workflows.outputs.base import BaseOutputs
+
+
+class MockNodeExecution(UniversalBaseModel):
+    when_condition: BaseDescriptor
+    then_outputs: BaseOutputs
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
+MockNodeExecutionArg = Sequence[Union[BaseOutputs, MockNodeExecution]]
--- a/vellum/workflows/runner/runner.py
+++ b/vellum/workflows/runner/runner.py
@@ -4,21 +4,7 @@ import logging
 from queue import Empty, Queue
 from threading import Event as ThreadingEvent, Thread
 from uuid import UUID
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Generic,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-    Sequence,
-    Set,
-    Tuple,
-    Type,
-    Union,
-)
+from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Tuple, Type, Union
 
 from vellum.workflows.constants import UNDEF
 from vellum.workflows.context import execution_context, get_parent_context
@@ -59,12 +45,12 @@ from vellum.workflows.events.workflow import (
 from vellum.workflows.exceptions import NodeException
 from vellum.workflows.nodes.bases import BaseNode
 from vellum.workflows.nodes.bases.base import NodeRunResponse
+from vellum.workflows.nodes.mocks import MockNodeExecutionArg
 from vellum.workflows.outputs import BaseOutputs
 from vellum.workflows.outputs.base import BaseOutput
 from vellum.workflows.ports.port import Port
 from vellum.workflows.references import ExternalInputReference, OutputReference
 from vellum.workflows.state.base import BaseState
-from vellum.workflows.types.cycle_map import CycleMap
 from vellum.workflows.types.generics import InputsType, OutputsType, StateType
 
 if TYPE_CHECKING:
@@ -88,7 +74,7 @@ class WorkflowRunner(Generic[StateType]):
         entrypoint_nodes: Optional[RunFromNodeArg] = None,
         external_inputs: Optional[ExternalInputsArg] = None,
         cancel_signal: Optional[ThreadingEvent] = None,
-        node_output_mocks: Optional[
+        node_output_mocks: Optional[MockNodeExecutionArg] = None,
         parent_context: Optional[ParentContext] = None,
         max_concurrency: Optional[int] = None,
     ):
@@ -144,7 +130,6 @@ class WorkflowRunner(Generic[StateType]):
 
         self._dependencies: Dict[Type[BaseNode], Set[Type[BaseNode]]] = defaultdict(set)
         self._state_forks: Set[StateType] = {self._initial_state}
-        self._mocks_by_node_outputs_class = CycleMap(items=node_output_mocks or [], key_by=lambda mock: mock.__class__)
 
         self._active_nodes_by_execution_id: Dict[UUID, BaseNode[StateType]] = {}
         self._cancel_signal = cancel_signal
@@ -156,6 +141,7 @@ class WorkflowRunner(Generic[StateType]):
             lambda s: self._snapshot_state(s),
         )
         self.workflow.context._register_event_queue(self._workflow_event_inner_queue)
+        self.workflow.context._register_node_output_mocks(node_output_mocks or [])
 
     def _snapshot_state(self, state: StateType) -> StateType:
         self._workflow_event_inner_queue.put(
@@ -201,11 +187,18 @@ class WorkflowRunner(Generic[StateType]):
             parent=parent_context,
         )
         node_run_response: NodeRunResponse
-
+        was_mocked = False
+        mock_candidates = self.workflow.context.node_output_mocks_map.get(node.Outputs) or []
+        for mock_candidate in mock_candidates:
+            if mock_candidate.when_condition.resolve(node.state):
+                node_run_response = mock_candidate.then_outputs
+                was_mocked = True
+                break
+
+        if not was_mocked:
             with execution_context(parent_context=updated_parent_context):
                 node_run_response = node.run()
-
-            node_run_response = self._mocks_by_node_outputs_class[node.Outputs]
+
         ports = node.Ports()
         if not isinstance(node_run_response, (BaseOutputs, Iterator)):
             raise NodeException(
@@ -519,19 +512,6 @@ class WorkflowRunner(Generic[StateType]):
         )
 
     def _stream(self) -> None:
-        # TODO: We should likely handle this during initialization
-        # https://app.shortcut.com/vellum/story/4327
-        if not self._entrypoints:
-            self._workflow_event_outer_queue.put(
-                self._reject_workflow_event(
-                    WorkflowError(
-                        message="No entrypoints defined",
-                        code=WorkflowErrorCode.INVALID_WORKFLOW,
-                    )
-                )
-            )
-            return
-
         for edge in self.workflow.get_edges():
             self._dependencies[edge.to_node].add(edge.from_port.node_class)
 
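
Taken together with the mocks.py change, the runner now does a conditional scan instead of the old CycleMap lookup that consumed one mock per execution: the first registered mock whose when_condition resolves truthy against the node's state short-circuits node.run(). A self-contained sketch of that dispatch with stand-in types (the real map lives on workflow.context.node_output_mocks_map):

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Type


@dataclass
class Mock:
    # Stand-in for MockNodeExecution: a predicate over state plus canned outputs.
    when_condition: Callable[[Any], bool]
    then_outputs: Any


@dataclass
class Context:
    # Stand-in for workflow.context.node_output_mocks_map, keyed by Outputs class.
    node_output_mocks_map: Dict[Type, List[Mock]] = field(default_factory=dict)


def run_node(node: Any, context: Context) -> Any:
    was_mocked = False
    node_run_response = None
    for candidate in context.node_output_mocks_map.get(type(node).Outputs, []):
        if candidate.when_condition(node.state):  # first matching mock wins
            node_run_response = candidate.then_outputs
            was_mocked = True
            break
    if not was_mocked:
        node_run_response = node.run()  # no mock matched: run the real node
    return node_run_response


class MyNode:
    class Outputs: ...

    state = {"flag": True}

    def run(self) -> str:
        return "real outputs"


ctx = Context({MyNode.Outputs: [Mock(when_condition=lambda s: s["flag"], then_outputs="mocked!")]})
print(run_node(MyNode(), ctx))  # -> mocked!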