vellum-ai 0.13.21__py3-none-any.whl → 0.13.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- vellum/__init__.py +12 -0
- vellum/client/__init__.py +170 -5
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/types/__init__.py +12 -0
- vellum/client/types/execute_api_request_bearer_token.py +6 -0
- vellum/client/types/execute_api_request_body.py +5 -0
- vellum/client/types/execute_api_request_headers_value.py +6 -0
- vellum/client/types/execute_api_response.py +24 -0
- vellum/client/types/method_enum.py +5 -0
- vellum/client/types/vellum_secret.py +19 -0
- vellum/plugins/pydantic.py +13 -1
- vellum/types/execute_api_request_bearer_token.py +3 -0
- vellum/types/execute_api_request_body.py +3 -0
- vellum/types/execute_api_request_headers_value.py +3 -0
- vellum/types/execute_api_response.py +3 -0
- vellum/types/method_enum.py +3 -0
- vellum/types/vellum_secret.py +3 -0
- vellum/workflows/nodes/core/inline_subworkflow_node/node.py +1 -0
- vellum/workflows/nodes/core/map_node/node.py +1 -0
- vellum/workflows/nodes/core/retry_node/node.py +1 -0
- vellum/workflows/nodes/core/try_node/node.py +11 -7
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +11 -3
- vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +16 -10
- vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +49 -0
- vellum/workflows/nodes/displayable/prompt_deployment_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +93 -0
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +19 -4
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/test_node.py +128 -0
- vellum/workflows/nodes/mocks.py +17 -0
- vellum/workflows/runner/runner.py +14 -34
- vellum/workflows/state/context.py +29 -1
- vellum/workflows/workflows/base.py +9 -6
- {vellum_ai-0.13.21.dist-info → vellum_ai-0.13.22.dist-info}/METADATA +1 -1
- {vellum_ai-0.13.21.dist-info → vellum_ai-0.13.22.dist-info}/RECORD +43 -26
- vellum_cli/push.py +5 -6
- vellum_cli/tests/test_push.py +81 -1
- vellum_ee/workflows/display/types.py +1 -31
- vellum_ee/workflows/display/workflows/base_workflow_display.py +46 -2
- vellum_ee/workflows/tests/test_server.py +9 -0
- vellum/workflows/types/cycle_map.py +0 -34
- {vellum_ai-0.13.21.dist-info → vellum_ai-0.13.22.dist-info}/LICENSE +0 -0
- {vellum_ai-0.13.21.dist-info → vellum_ai-0.13.22.dist-info}/WHEEL +0 -0
- {vellum_ai-0.13.21.dist-info → vellum_ai-0.13.22.dist-info}/entry_points.txt +0 -0
vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py

@@ -5,6 +5,8 @@ from typing import Any, Iterator, List
 
 from vellum.client.core.api_error import ApiError
 from vellum.client.core.pydantic_utilities import UniversalBaseModel
+from vellum.client.types.chat_message import ChatMessage
+from vellum.client.types.chat_message_request import ChatMessageRequest
 from vellum.client.types.execute_prompt_event import ExecutePromptEvent
 from vellum.client.types.fulfilled_execute_prompt_event import FulfilledExecutePromptEvent
 from vellum.client.types.function_call import FunctionCall
@@ -12,6 +14,7 @@ from vellum.client.types.function_call_vellum_value import FunctionCallVellumValue
 from vellum.client.types.function_definition import FunctionDefinition
 from vellum.client.types.initiated_execute_prompt_event import InitiatedExecutePromptEvent
 from vellum.client.types.prompt_output import PromptOutput
+from vellum.client.types.prompt_request_chat_history_input import PromptRequestChatHistoryInput
 from vellum.client.types.prompt_request_json_input import PromptRequestJsonInput
 from vellum.client.types.string_vellum_value import StringVellumValue
 from vellum.workflows.errors.types import WorkflowErrorCode
@@ -181,3 +184,49 @@ def test_inline_prompt_node__api_error__invalid_inputs_node_exception(
     # THEN the node raises the correct NodeException
     assert e.value.code == expected_code
     assert e.value.message == expected_message
+
+
+def test_inline_prompt_node__chat_history_inputs(vellum_adhoc_prompt_client):
+    # GIVEN a prompt node with a chat history input
+    class MyNode(InlinePromptNode):
+        ml_model = "gpt-4o"
+        blocks = []
+        prompt_inputs = {
+            "chat_history": [ChatMessageRequest(role="USER", text="Hello, how are you?")],
+        }
+
+    # AND a known response from invoking an inline prompt
+    expected_outputs: List[PromptOutput] = [
+        StringVellumValue(value="Great!"),
+    ]
+
+    def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
+        execution_id = str(uuid4())
+        events: List[ExecutePromptEvent] = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            FulfilledExecutePromptEvent(
+                execution_id=execution_id,
+                outputs=expected_outputs,
+            ),
+        ]
+        yield from events
+
+    vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.side_effect = generate_prompt_events
+
+    # WHEN the node is run
+    events = list(MyNode().run())
+
+    # THEN the prompt is executed with the correct inputs
+    assert events[-1].value == "Great!"
+
+    # AND the prompt is executed with the correct inputs
+    mock_api = vellum_adhoc_prompt_client.adhoc_execute_prompt_stream
+    assert mock_api.call_count == 1
+    assert mock_api.call_args.kwargs["input_values"] == [
+        PromptRequestChatHistoryInput(
+            key="chat_history",
+            type="CHAT_HISTORY",
+            value=[ChatMessage(role="USER", text="Hello, how are you?")],
+        ),
+    ]
+    assert mock_api.call_args.kwargs["input_variables"][0].type == "CHAT_HISTORY"
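A detail these assertions pin down: callers put ChatMessageRequest objects into prompt_inputs, but the compiled PromptRequestChatHistoryInput carries ChatMessage values. A minimal sketch of that coercion, using only types imported in the test above (the helper function is illustrative, not code from the package):

from typing import List

from vellum.client.types.chat_message import ChatMessage
from vellum.client.types.chat_message_request import ChatMessageRequest
from vellum.client.types.prompt_request_chat_history_input import PromptRequestChatHistoryInput


def compile_chat_history(key: str, messages: List[ChatMessageRequest]) -> PromptRequestChatHistoryInput:
    # Illustrative only: re-wrap each request message as a ChatMessage,
    # mirroring what the test asserts about the node's API call
    return PromptRequestChatHistoryInput(
        key=key,
        type="CHAT_HISTORY",
        value=[ChatMessage(role=m.role, text=m.text) for m in messages],
    )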
vellum/workflows/nodes/displayable/prompt_deployment_node/tests/__init__.py
File without changes
vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py

@@ -0,0 +1,93 @@
+from uuid import uuid4
+from typing import Any, Iterator, List
+
+from vellum.client.types.chat_history_input_request import ChatHistoryInputRequest
+from vellum.client.types.chat_message_request import ChatMessageRequest
+from vellum.client.types.execute_prompt_event import ExecutePromptEvent
+from vellum.client.types.fulfilled_execute_prompt_event import FulfilledExecutePromptEvent
+from vellum.client.types.initiated_execute_prompt_event import InitiatedExecutePromptEvent
+from vellum.client.types.json_input_request import JsonInputRequest
+from vellum.client.types.string_vellum_value import StringVellumValue
+from vellum.workflows.nodes.displayable.prompt_deployment_node.node import PromptDeploymentNode
+
+
+def test_run_node__chat_history_input(vellum_client):
+    """Confirm that we can successfully invoke a Prompt Deployment Node that uses Chat History Inputs"""
+
+    # GIVEN a Prompt Deployment Node
+    class ExamplePromptDeploymentNode(PromptDeploymentNode):
+        deployment = "example_prompt_deployment"
+        prompt_inputs = {
+            "chat_history": [ChatMessageRequest(role="USER", text="Hello, how are you?")],
+        }
+
+    # AND we know what the Prompt Deployment will respond with
+    def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
+        execution_id = str(uuid4())
+        events: List[ExecutePromptEvent] = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            FulfilledExecutePromptEvent(
+                execution_id=execution_id,
+                outputs=[
+                    StringVellumValue(value="Great!"),
+                ],
+            ),
+        ]
+        yield from events
+
+    vellum_client.execute_prompt_stream.side_effect = generate_prompt_events
+
+    # WHEN we run the node
+    node = ExamplePromptDeploymentNode()
+    events = list(node.run())
+
+    # THEN the node should have completed successfully
+    assert events[-1].value == "Great!"
+
+    # AND we should have invoked the Prompt Deployment with the expected inputs
+    call_kwargs = vellum_client.execute_prompt_stream.call_args.kwargs
+    assert call_kwargs["inputs"] == [
+        ChatHistoryInputRequest(
+            name="chat_history", value=[ChatMessageRequest(role="USER", text="Hello, how are you?")]
+        ),
+    ]
+
+
+def test_run_node__any_array_input(vellum_client):
+    """Confirm that we can successfully invoke a Prompt Deployment Node that uses any array input"""
+
+    # GIVEN a Prompt Deployment Node
+    class ExamplePromptDeploymentNode(PromptDeploymentNode):
+        deployment = "example_prompt_deployment"
+        prompt_inputs = {
+            "fruits": ["apple", "banana", "cherry"],
+        }
+
+    # AND we know what the Prompt Deployment will respond with
+    def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
+        execution_id = str(uuid4())
+        events: List[ExecutePromptEvent] = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            FulfilledExecutePromptEvent(
+                execution_id=execution_id,
+                outputs=[
+                    StringVellumValue(value="Great!"),
+                ],
+            ),
+        ]
+        yield from events
+
+    vellum_client.execute_prompt_stream.side_effect = generate_prompt_events
+
+    # WHEN we run the node
+    node = ExamplePromptDeploymentNode()
+    events = list(node.run())
+
+    # THEN the node should have completed successfully
+    assert events[-1].value == "Great!"
+
+    # AND we should have invoked the Prompt Deployment with the expected inputs
+    call_kwargs = vellum_client.execute_prompt_stream.call_args.kwargs
+    assert call_kwargs["inputs"] == [
+        JsonInputRequest(name="fruits", value=["apple", "banana", "cherry"]),
+    ]
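Together the two tests pin down the dispatch rule for prompt_inputs values: a list whose elements are all chat messages compiles to a ChatHistoryInputRequest, and any other array falls through to a JsonInputRequest. A hedged sketch of that rule (the helper is illustrative; the node's real compilation code is not shown in this diff, though the subworkflow node below uses the same isinstance check):

from typing import Any, Union

from vellum.client.types.chat_history_input_request import ChatHistoryInputRequest
from vellum.client.types.chat_message import ChatMessage
from vellum.client.types.chat_message_request import ChatMessageRequest
from vellum.client.types.json_input_request import JsonInputRequest


def compile_prompt_input(name: str, value: Any) -> Union[ChatHistoryInputRequest, JsonInputRequest]:
    # A list counts as chat history only if every element is a chat message
    if isinstance(value, list) and all(isinstance(m, (ChatMessage, ChatMessageRequest)) for m in value):
        return ChatHistoryInputRequest(name=name, value=value)
    # Anything else ships as a JSON input
    return JsonInputRequest(name=name, value=value)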
vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py

@@ -1,3 +1,4 @@
+import json
 from uuid import UUID
 from typing import Any, ClassVar, Dict, Generic, Iterator, List, Optional, Set, Union, cast
 
@@ -11,11 +12,13 @@ from vellum import (
     WorkflowRequestNumberInputRequest,
     WorkflowRequestStringInputRequest,
 )
+from vellum.client.types.chat_message_request import ChatMessageRequest
 from vellum.core import RequestOptions
 from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT
 from vellum.workflows.context import get_parent_context
 from vellum.workflows.errors import WorkflowErrorCode
 from vellum.workflows.errors.types import workflow_event_error_to_workflow_error
+from vellum.workflows.events.types import default_serializer
 from vellum.workflows.exceptions import NodeException
 from vellum.workflows.nodes.bases.base import BaseNode
 from vellum.workflows.outputs.base import BaseOutput
@@ -66,7 +69,9 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
                         value=input_value,
                     )
                 )
-            elif isinstance(input_value, list) and all(
+            elif isinstance(input_value, list) and all(
+                isinstance(message, (ChatMessage, ChatMessageRequest)) for message in input_value
+            ):
                 compiled_inputs.append(
                     WorkflowRequestChatHistoryInputRequest(
                         name=input_name,
@@ -88,9 +93,19 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
                     )
                 )
             else:
-
-
-
+                try:
+                    input_value = default_serializer(input_value)
+                except json.JSONDecodeError as e:
+                    raise NodeException(
+                        message=f"Failed to serialize input '{input_name}' of type '{input_value.__class__}': {e}",
+                        code=WorkflowErrorCode.INVALID_INPUTS,
+                    )
+
+                compiled_inputs.append(
+                    WorkflowRequestJsonInputRequest(
+                        name=input_name,
+                        value=input_value,
+                    )
                 )
 
         return compiled_inputs
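The rewritten else branch routes every value that is not a string or chat history through default_serializer before wrapping it in a WorkflowRequestJsonInputRequest, and fails fast with INVALID_INPUTS if serialization blows up. In practice structured values can now be passed straight through subworkflow_inputs; a hedged sketch (the node name and input shape are hypothetical, and this assumes default_serializer accepts plain dicts and lists):

from vellum.workflows.nodes.displayable.subworkflow_deployment_node.node import SubworkflowDeploymentNode


class SubmitLineItemsNode(SubworkflowDeploymentNode):
    deployment = "example_subworkflow_deployment"
    subworkflow_inputs = {
        # Neither a string nor a list of chat messages, so it takes the
        # serialize-then-JSON path added in this version
        "line_items": [
            {"sku": "ABC-123", "quantity": 2},
            {"sku": "XYZ-999", "quantity": 1},
        ],
    }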
vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/__init__.py
File without changes
vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/test_node.py

@@ -0,0 +1,128 @@
+from datetime import datetime
+from uuid import uuid4
+from typing import Any, Iterator, List
+
+from vellum.client.types.chat_message_request import ChatMessageRequest
+from vellum.client.types.workflow_execution_workflow_result_event import WorkflowExecutionWorkflowResultEvent
+from vellum.client.types.workflow_output_string import WorkflowOutputString
+from vellum.client.types.workflow_request_chat_history_input_request import WorkflowRequestChatHistoryInputRequest
+from vellum.client.types.workflow_request_json_input_request import WorkflowRequestJsonInputRequest
+from vellum.client.types.workflow_result_event import WorkflowResultEvent
+from vellum.client.types.workflow_stream_event import WorkflowStreamEvent
+from vellum.workflows.nodes.displayable.subworkflow_deployment_node.node import SubworkflowDeploymentNode
+
+
+def test_run_workflow__chat_history_input(vellum_client):
+    """Confirm that we can successfully invoke a Subworkflow Deployment Node that uses Chat History Inputs"""
+
+    # GIVEN a Subworkflow Deployment Node
+    class ExampleSubworkflowDeploymentNode(SubworkflowDeploymentNode):
+        deployment = "example_subworkflow_deployment"
+        subworkflow_inputs = {
+            "chat_history": [ChatMessageRequest(role="USER", text="Hello, how are you?")],
+        }
+
+    # AND we know what the Subworkflow Deployment will respond with
+    def generate_subworkflow_events(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
+        execution_id = str(uuid4())
+        expected_events: List[WorkflowStreamEvent] = [
+            WorkflowExecutionWorkflowResultEvent(
+                execution_id=execution_id,
+                data=WorkflowResultEvent(
+                    id=str(uuid4()),
+                    state="INITIATED",
+                    ts=datetime.now(),
+                ),
+            ),
+            WorkflowExecutionWorkflowResultEvent(
+                execution_id=execution_id,
+                data=WorkflowResultEvent(
+                    id=str(uuid4()),
+                    state="FULFILLED",
+                    ts=datetime.now(),
+                    outputs=[
+                        WorkflowOutputString(
+                            id=str(uuid4()),
+                            name="greeting",
+                            value="Great!",
+                        )
+                    ],
+                ),
+            ),
+        ]
+        yield from expected_events
+
+    vellum_client.execute_workflow_stream.side_effect = generate_subworkflow_events
+
+    # WHEN we run the node
+    node = ExampleSubworkflowDeploymentNode()
+    events = list(node.run())
+
+    # THEN the node should have completed successfully
+    assert events[-1].name == "greeting"
+    assert events[-1].value == "Great!"
+
+    # AND we should have invoked the Subworkflow Deployment with the expected inputs
+    call_kwargs = vellum_client.execute_workflow_stream.call_args.kwargs
+    assert call_kwargs["inputs"] == [
+        WorkflowRequestChatHistoryInputRequest(
+            name="chat_history", value=[ChatMessageRequest(role="USER", text="Hello, how are you?")]
+        ),
+    ]
+
+
+def test_run_workflow__any_array(vellum_client):
+    """Confirm that we can successfully invoke a Subworkflow Deployment Node that uses any array input"""
+
+    # GIVEN a Subworkflow Deployment Node
+    class ExampleSubworkflowDeploymentNode(SubworkflowDeploymentNode):
+        deployment = "example_subworkflow_deployment"
+        subworkflow_inputs = {
+            "fruits": ["apple", "banana", "cherry"],
+        }
+
+    # AND we know what the Subworkflow Deployment will respond with
+    def generate_subworkflow_events(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
+        execution_id = str(uuid4())
+        expected_events: List[WorkflowStreamEvent] = [
+            WorkflowExecutionWorkflowResultEvent(
+                execution_id=execution_id,
+                data=WorkflowResultEvent(
+                    id=str(uuid4()),
+                    state="INITIATED",
+                    ts=datetime.now(),
+                ),
+            ),
+            WorkflowExecutionWorkflowResultEvent(
+                execution_id=execution_id,
+                data=WorkflowResultEvent(
+                    id=str(uuid4()),
+                    state="FULFILLED",
+                    ts=datetime.now(),
+                    outputs=[
+                        WorkflowOutputString(
+                            id=str(uuid4()),
+                            name="greeting",
+                            value="Great!",
+                        )
+                    ],
+                ),
+            ),
+        ]
+        yield from expected_events
+
+    vellum_client.execute_workflow_stream.side_effect = generate_subworkflow_events
+
+    # WHEN we run the node
+    node = ExampleSubworkflowDeploymentNode()
+    events = list(node.run())
+
+    # THEN the node should have completed successfully
+    assert events[-1].name == "greeting"
+    assert events[-1].value == "Great!"
+
+    # AND we should have invoked the Subworkflow Deployment with the expected inputs
+    call_kwargs = vellum_client.execute_workflow_stream.call_args.kwargs
+    assert call_kwargs["inputs"] == [
+        WorkflowRequestJsonInputRequest(name="fruits", value=["apple", "banana", "cherry"]),
+    ]
vellum/workflows/nodes/mocks.py

@@ -0,0 +1,17 @@
+from typing import Sequence, Union
+
+from pydantic import ConfigDict
+
+from vellum.client.core.pydantic_utilities import UniversalBaseModel
+from vellum.workflows.descriptors.base import BaseDescriptor
+from vellum.workflows.outputs.base import BaseOutputs
+
+
+class MockNodeExecution(UniversalBaseModel):
+    when_condition: BaseDescriptor
+    then_outputs: BaseOutputs
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
+MockNodeExecutionArg = Sequence[Union[BaseOutputs, MockNodeExecution]]
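This new module is the core of the release's reworked mocking support: node_output_mocks (threaded through runner.py, context.py, and base.py below) accepts either bare Outputs instances or conditional MockNodeExecution entries. A hedged usage sketch; MyWorkflow, MyNode, and the result output are hypothetical stand-ins:

from vellum.workflows.nodes.mocks import MockNodeExecution
from vellum.workflows.references.constant import ConstantValueReference

# Unconditional: a bare Outputs instance mocks every execution of MyNode
terminal_event = MyWorkflow().run(
    node_output_mocks=[MyNode.Outputs(result="mocked value")],
)

# Conditional: the mock is used only when when_condition resolves truthy
terminal_event = MyWorkflow().run(
    node_output_mocks=[
        MockNodeExecution(
            when_condition=ConstantValueReference(True),
            then_outputs=MyNode.Outputs(result="mocked value"),
        ),
    ],
)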
vellum/workflows/runner/runner.py

@@ -4,21 +4,7 @@ import logging
 from queue import Empty, Queue
 from threading import Event as ThreadingEvent, Thread
 from uuid import UUID
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Generic,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-    Sequence,
-    Set,
-    Tuple,
-    Type,
-    Union,
-)
+from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Tuple, Type, Union
 
 from vellum.workflows.constants import UNDEF
 from vellum.workflows.context import execution_context, get_parent_context
@@ -59,12 +45,12 @@ from vellum.workflows.events.workflow import (
 from vellum.workflows.exceptions import NodeException
 from vellum.workflows.nodes.bases import BaseNode
 from vellum.workflows.nodes.bases.base import NodeRunResponse
+from vellum.workflows.nodes.mocks import MockNodeExecutionArg
 from vellum.workflows.outputs import BaseOutputs
 from vellum.workflows.outputs.base import BaseOutput
 from vellum.workflows.ports.port import Port
 from vellum.workflows.references import ExternalInputReference, OutputReference
 from vellum.workflows.state.base import BaseState
-from vellum.workflows.types.cycle_map import CycleMap
 from vellum.workflows.types.generics import InputsType, OutputsType, StateType
 
 if TYPE_CHECKING:
@@ -88,7 +74,7 @@ class WorkflowRunner(Generic[StateType]):
         entrypoint_nodes: Optional[RunFromNodeArg] = None,
         external_inputs: Optional[ExternalInputsArg] = None,
         cancel_signal: Optional[ThreadingEvent] = None,
-        node_output_mocks: Optional[
+        node_output_mocks: Optional[MockNodeExecutionArg] = None,
         parent_context: Optional[ParentContext] = None,
         max_concurrency: Optional[int] = None,
     ):
@@ -144,7 +130,6 @@ class WorkflowRunner(Generic[StateType]):
 
         self._dependencies: Dict[Type[BaseNode], Set[Type[BaseNode]]] = defaultdict(set)
         self._state_forks: Set[StateType] = {self._initial_state}
-        self._mocks_by_node_outputs_class = CycleMap(items=node_output_mocks or [], key_by=lambda mock: mock.__class__)
 
         self._active_nodes_by_execution_id: Dict[UUID, BaseNode[StateType]] = {}
         self._cancel_signal = cancel_signal
@@ -156,6 +141,7 @@ class WorkflowRunner(Generic[StateType]):
             lambda s: self._snapshot_state(s),
         )
         self.workflow.context._register_event_queue(self._workflow_event_inner_queue)
+        self.workflow.context._register_node_output_mocks(node_output_mocks or [])
 
     def _snapshot_state(self, state: StateType) -> StateType:
         self._workflow_event_inner_queue.put(
@@ -201,11 +187,18 @@ class WorkflowRunner(Generic[StateType]):
             parent=parent_context,
         )
         node_run_response: NodeRunResponse
-
+        was_mocked = False
+        mock_candidates = self.workflow.context.node_output_mocks_map.get(node.Outputs) or []
+        for mock_candidate in mock_candidates:
+            if mock_candidate.when_condition.resolve(node.state):
+                node_run_response = mock_candidate.then_outputs
+                was_mocked = True
+                break
+
+        if not was_mocked:
             with execution_context(parent_context=updated_parent_context):
                 node_run_response = node.run()
-
-            node_run_response = self._mocks_by_node_outputs_class[node.Outputs]
+
         ports = node.Ports()
         if not isinstance(node_run_response, (BaseOutputs, Iterator)):
             raise NodeException(
@@ -519,19 +512,6 @@ class WorkflowRunner(Generic[StateType]):
         )
 
     def _stream(self) -> None:
-        # TODO: We should likely handle this during initialization
-        # https://app.shortcut.com/vellum/story/4327
-        if not self._entrypoints:
-            self._workflow_event_outer_queue.put(
-                self._reject_workflow_event(
-                    WorkflowError(
-                        message="No entrypoints defined",
-                        code=WorkflowErrorCode.INVALID_WORKFLOW,
-                    )
-                )
-            )
-            return
-
         for edge in self.workflow.get_edges():
             self._dependencies[edge.to_node].add(edge.from_port.node_class)
 
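The mock-resolution loop above replaces the old CycleMap lookup with explicit first-match-wins semantics: candidates registered for the node's Outputs class are checked in order, the first when_condition that resolves truthy against the node's state supplies then_outputs, and the node runs for real only when nothing matches. The same selection logic in isolation, as a hedged sketch (MyNode and state are stand-ins):

from vellum.workflows.nodes.mocks import MockNodeExecution
from vellum.workflows.references.constant import ConstantValueReference

candidates = [
    MockNodeExecution(
        when_condition=ConstantValueReference(False),  # never matches
        then_outputs=MyNode.Outputs(result="unused"),
    ),
    MockNodeExecution(
        when_condition=ConstantValueReference(True),  # first truthy match wins
        then_outputs=MyNode.Outputs(result="used"),
    ),
]

# None here means: no mock matched, so the node would run for real
chosen = next(
    (c.then_outputs for c in candidates if c.when_condition.resolve(state)),
    None,
)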
vellum/workflows/state/context.py

@@ -1,9 +1,12 @@
 from functools import cached_property
 from queue import Queue
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Type
 
 from vellum import Vellum
 from vellum.workflows.events.types import ParentContext
+from vellum.workflows.nodes.mocks import MockNodeExecution, MockNodeExecutionArg
+from vellum.workflows.outputs.base import BaseOutputs
+from vellum.workflows.references.constant import ConstantValueReference
 from vellum.workflows.vellum_client import create_vellum_client
 
 if TYPE_CHECKING:
@@ -20,6 +23,7 @@ class WorkflowContext:
        self._vellum_client = vellum_client
        self._parent_context = parent_context
        self._event_queue: Optional[Queue["WorkflowEvent"]] = None
+        self._node_output_mocks_map: Dict[Type[BaseOutputs], List[MockNodeExecution]] = {}
 
     @cached_property
     def vellum_client(self) -> Vellum:
@@ -34,9 +38,33 @@ class WorkflowContext:
             return self._parent_context
         return None
 
+    @cached_property
+    def node_output_mocks_map(self) -> Dict[Type[BaseOutputs], List[MockNodeExecution]]:
+        return self._node_output_mocks_map
+
     def _emit_subworkflow_event(self, event: "WorkflowEvent") -> None:
         if self._event_queue:
             self._event_queue.put(event)
 
     def _register_event_queue(self, event_queue: Queue["WorkflowEvent"]) -> None:
         self._event_queue = event_queue
+
+    def _register_node_output_mocks(self, node_output_mocks: MockNodeExecutionArg) -> None:
+        for mock in node_output_mocks:
+            if isinstance(mock, MockNodeExecution):
+                key = mock.then_outputs.__class__
+                value = mock
+            else:
+                key = mock.__class__
+                value = MockNodeExecution(
+                    when_condition=ConstantValueReference(True),
+                    then_outputs=mock,
+                )
+
+            if key not in self._node_output_mocks_map:
+                self._node_output_mocks_map[key] = [value]
+            else:
+                self._node_output_mocks_map[key].append(value)
+
+    def _get_all_node_output_mocks(self) -> List[MockNodeExecution]:
+        return [mock for mocks in self._node_output_mocks_map.values() for mock in mocks]
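_register_node_output_mocks normalizes both accepted shapes into one map keyed by Outputs class: a bare Outputs instance is wrapped in a MockNodeExecution whose when_condition is ConstantValueReference(True), i.e. an always-on mock. Under the same hypothetical MyNode as the earlier sketches, these two registrations are therefore stored identically:

from vellum.workflows.nodes.mocks import MockNodeExecution
from vellum.workflows.references.constant import ConstantValueReference

# Shorthand form...
context._register_node_output_mocks([MyNode.Outputs(result="mocked value")])

# ...is equivalent to the explicit form:
context._register_node_output_mocks([
    MockNodeExecution(
        when_condition=ConstantValueReference(True),
        then_outputs=MyNode.Outputs(result="mocked value"),
    ),
])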
vellum/workflows/workflows/base.py

@@ -61,6 +61,7 @@ from vellum.workflows.events.workflow import (
 from vellum.workflows.graph import Graph
 from vellum.workflows.inputs.base import BaseInputs
 from vellum.workflows.nodes.bases import BaseNode
+from vellum.workflows.nodes.mocks import MockNodeExecutionArg
 from vellum.workflows.outputs import BaseOutputs
 from vellum.workflows.resolvers.base import BaseWorkflowResolver
 from vellum.workflows.runner import WorkflowRunner
@@ -187,7 +188,7 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
         entrypoint_nodes: Optional[RunFromNodeArg] = None,
         external_inputs: Optional[ExternalInputsArg] = None,
         cancel_signal: Optional[ThreadingEvent] = None,
-        node_output_mocks: Optional[
+        node_output_mocks: Optional[MockNodeExecutionArg] = None,
         max_concurrency: Optional[int] = None,
     ) -> TerminalWorkflowEvent:
         """
@@ -214,8 +215,9 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
         cancel_signal: Optional[ThreadingEvent] = None
             A threading event that can be used to cancel the Workflow Execution.
 
-        node_output_mocks: Optional[
-            A list of Outputs to mock for Nodes during Workflow Execution.
+        node_output_mocks: Optional[MockNodeExecutionArg] = None
+            A list of Outputs to mock for Nodes during Workflow Execution. Each mock can include a `when_condition`
+            that must be met for the mock to be used.
 
         max_concurrency: Optional[int] = None
             The max number of concurrent threads to run the Workflow with. If not provided, the Workflow will run
@@ -295,7 +297,7 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
         entrypoint_nodes: Optional[RunFromNodeArg] = None,
         external_inputs: Optional[ExternalInputsArg] = None,
         cancel_signal: Optional[ThreadingEvent] = None,
-        node_output_mocks: Optional[
+        node_output_mocks: Optional[MockNodeExecutionArg] = None,
         max_concurrency: Optional[int] = None,
     ) -> WorkflowEventStream:
         """
@@ -323,8 +325,9 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
         cancel_signal: Optional[ThreadingEvent] = None
             A threading event that can be used to cancel the Workflow Execution.
 
-        node_output_mocks: Optional[
-            A list of Outputs to mock for Nodes during Workflow Execution.
+        node_output_mocks: Optional[MockNodeExecutionArg] = None
+            A list of Outputs to mock for Nodes during Workflow Execution. Each mock can include a `when_condition`
+            that must be met for the mock to be used.
 
         max_concurrency: Optional[int] = None
             The max number of concurrent threads to run the Workflow with. If not provided, the Workflow will run