vellum-ai 0.13.21__py3-none-any.whl → 0.13.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. vellum/__init__.py +12 -0
  2. vellum/client/__init__.py +170 -5
  3. vellum/client/core/client_wrapper.py +1 -1
  4. vellum/client/types/__init__.py +12 -0
  5. vellum/client/types/execute_api_request_bearer_token.py +6 -0
  6. vellum/client/types/execute_api_request_body.py +5 -0
  7. vellum/client/types/execute_api_request_headers_value.py +6 -0
  8. vellum/client/types/execute_api_response.py +24 -0
  9. vellum/client/types/method_enum.py +5 -0
  10. vellum/client/types/vellum_secret.py +19 -0
  11. vellum/plugins/pydantic.py +13 -1
  12. vellum/types/execute_api_request_bearer_token.py +3 -0
  13. vellum/types/execute_api_request_body.py +3 -0
  14. vellum/types/execute_api_request_headers_value.py +3 -0
  15. vellum/types/execute_api_response.py +3 -0
  16. vellum/types/method_enum.py +3 -0
  17. vellum/types/vellum_secret.py +3 -0
  18. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +1 -0
  19. vellum/workflows/nodes/core/map_node/node.py +1 -0
  20. vellum/workflows/nodes/core/retry_node/node.py +1 -0
  21. vellum/workflows/nodes/core/try_node/node.py +11 -7
  22. vellum/workflows/nodes/displayable/api_node/node.py +12 -3
  23. vellum/workflows/nodes/displayable/api_node/tests/__init__.py +0 -0
  24. vellum/workflows/nodes/displayable/api_node/tests/test_api_node.py +34 -0
  25. vellum/workflows/nodes/displayable/bases/api_node/node.py +25 -4
  26. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +11 -3
  27. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +27 -12
  28. vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +33 -0
  29. vellum/workflows/nodes/displayable/code_execution_node/utils.py +14 -12
  30. vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +49 -0
  31. vellum/workflows/nodes/displayable/prompt_deployment_node/tests/__init__.py +0 -0
  32. vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +96 -0
  33. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +29 -5
  34. vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/__init__.py +0 -0
  35. vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/test_node.py +131 -0
  36. vellum/workflows/nodes/mocks.py +17 -0
  37. vellum/workflows/runner/runner.py +14 -34
  38. vellum/workflows/state/context.py +29 -1
  39. vellum/workflows/workflows/base.py +9 -6
  40. {vellum_ai-0.13.21.dist-info → vellum_ai-0.13.23.dist-info}/METADATA +1 -1
  41. {vellum_ai-0.13.21.dist-info → vellum_ai-0.13.23.dist-info}/RECORD +50 -31
  42. vellum_cli/push.py +5 -6
  43. vellum_cli/tests/test_push.py +81 -1
  44. vellum_ee/workflows/display/types.py +1 -31
  45. vellum_ee/workflows/display/vellum.py +1 -1
  46. vellum_ee/workflows/display/workflows/base_workflow_display.py +46 -2
  47. vellum_ee/workflows/tests/test_server.py +9 -0
  48. vellum/workflows/types/cycle_map.py +0 -34
  49. {vellum_ai-0.13.21.dist-info → vellum_ai-0.13.23.dist-info}/LICENSE +0 -0
  50. {vellum_ai-0.13.21.dist-info → vellum_ai-0.13.23.dist-info}/WHEEL +0 -0
  51. {vellum_ai-0.13.21.dist-info → vellum_ai-0.13.23.dist-info}/entry_points.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  import json
2
2
  from uuid import uuid4
3
- from typing import Callable, ClassVar, Generic, Iterator, List, Optional, Tuple, Union, cast
3
+ from typing import Callable, ClassVar, Generic, Iterator, List, Optional, Tuple, Union
4
4
 
5
5
  from vellum import (
6
6
  AdHocExecutePromptEvent,
@@ -16,6 +16,7 @@ from vellum import (
16
16
  VellumVariable,
17
17
  )
18
18
  from vellum.client import RequestOptions
19
+ from vellum.client.types.chat_message_request import ChatMessageRequest
19
20
  from vellum.workflows.constants import OMIT
20
21
  from vellum.workflows.context import get_parent_context
21
22
  from vellum.workflows.errors import WorkflowErrorCode
@@ -108,7 +109,14 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
108
109
  value=input_value,
109
110
  )
110
111
  )
111
- elif isinstance(input_value, list) and all(isinstance(message, ChatMessage) for message in input_value):
112
+ elif isinstance(input_value, list) and all(
113
+ isinstance(message, (ChatMessage, ChatMessageRequest)) for message in input_value
114
+ ):
115
+ chat_history = [
116
+ message if isinstance(message, ChatMessage) else ChatMessage.model_validate(message.model_dump())
117
+ for message in input_value
118
+ if isinstance(message, (ChatMessage, ChatMessageRequest))
119
+ ]
112
120
  input_variables.append(
113
121
  VellumVariable(
114
122
  # TODO: Determine whether or not we actually need an id here and if we do,
@@ -122,7 +130,7 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
122
130
  input_values.append(
123
131
  PromptRequestChatHistoryInput(
124
132
  key=input_name,
125
- value=cast(List[ChatMessage], input_value),
133
+ value=chat_history,
126
134
  )
127
135
  )
128
136
  else:
@@ -1,5 +1,6 @@
1
+ import json
1
2
  from uuid import UUID
2
- from typing import Any, ClassVar, Dict, Generic, Iterator, List, Optional, Sequence, Union, cast
3
+ from typing import Any, ClassVar, Dict, Generic, Iterator, List, Optional, Sequence, Union
3
4
 
4
5
  from vellum import (
5
6
  ChatHistoryInputRequest,
@@ -12,9 +13,11 @@ from vellum import (
12
13
  StringInputRequest,
13
14
  )
14
15
  from vellum.client import RequestOptions
16
+ from vellum.client.types.chat_message_request import ChatMessageRequest
15
17
  from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT
16
18
  from vellum.workflows.context import get_parent_context
17
19
  from vellum.workflows.errors import WorkflowErrorCode
20
+ from vellum.workflows.events.types import default_serializer
18
21
  from vellum.workflows.exceptions import NodeException
19
22
  from vellum.workflows.nodes.displayable.bases.base_prompt_node import BasePromptNode
20
23
  from vellum.workflows.types import MergeBehavior
@@ -89,26 +92,38 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
89
92
  value=input_value,
90
93
  )
91
94
  )
92
- elif isinstance(input_value, list) and all(isinstance(message, ChatMessage) for message in input_value):
95
+ elif isinstance(input_value, list) and all(
96
+ isinstance(message, (ChatMessage, ChatMessageRequest)) for message in input_value
97
+ ):
98
+ chat_history = [
99
+ (
100
+ message
101
+ if isinstance(message, ChatMessageRequest)
102
+ else ChatMessageRequest.model_validate(message.model_dump())
103
+ )
104
+ for message in input_value
105
+ if isinstance(message, (ChatMessage, ChatMessageRequest))
106
+ ]
93
107
  compiled_inputs.append(
94
108
  ChatHistoryInputRequest(
95
109
  name=input_name,
96
- value=cast(List[ChatMessage], input_value),
110
+ value=chat_history,
97
111
  )
98
112
  )
99
- elif isinstance(input_value, dict):
100
- # Note: We may want to fail early here if we know that input_value is not
101
- # JSON serializable.
113
+ else:
114
+ try:
115
+ input_value = default_serializer(input_value)
116
+ except json.JSONDecodeError as e:
117
+ raise NodeException(
118
+ message=f"Failed to serialize input '{input_name}' of type '{input_value.__class__}': {e}",
119
+ code=WorkflowErrorCode.INVALID_INPUTS,
120
+ )
121
+
102
122
  compiled_inputs.append(
103
123
  JsonInputRequest(
104
124
  name=input_name,
105
- value=cast(Dict[str, Any], input_value),
125
+ value=input_value,
106
126
  )
107
127
  )
108
- else:
109
- raise NodeException(
110
- message=f"Unrecognized input type for input '{input_name}': {input_value.__class__}",
111
- code=WorkflowErrorCode.INVALID_INPUTS,
112
- )
113
128
 
114
129
  return compiled_inputs
@@ -1,5 +1,6 @@
1
1
  import pytest
2
2
  import os
3
+ from typing import Any
3
4
 
4
5
  from vellum import CodeExecutorResponse, NumberVellumValue, StringInput
5
6
  from vellum.client.types.code_execution_package import CodeExecutionPackage
@@ -413,3 +414,35 @@ name
413
414
  Field required [type=missing, input_value={'n': 'hello', 'a': {}}, input_type=dict]\
414
415
  """
415
416
  )
417
+
418
+
419
+ def test_run_workflow__run_inline__valid_dict_to_pydantic_any_type():
420
+ """Confirm that CodeExecutionNodes can convert a dict to a Pydantic model during inline execution."""
421
+
422
+ # GIVEN a node that subclasses CodeExecutionNode that returns a dict matching Any
423
+ class ExampleCodeExecutionNode(CodeExecutionNode[BaseState, Any]):
424
+ code = """\
425
+ def main(word: str) -> dict:
426
+ return {
427
+ "name": "word",
428
+ "arguments": {},
429
+ }
430
+ """
431
+ runtime = "PYTHON_3_11_6"
432
+
433
+ code_inputs = {
434
+ "word": "hello",
435
+ }
436
+
437
+ # WHEN we run the node
438
+ node = ExampleCodeExecutionNode()
439
+ outputs = node.run()
440
+
441
+ # THEN the node should have produced the outputs we expect
442
+ assert outputs == {
443
+ "result": {
444
+ "name": "word",
445
+ "arguments": {},
446
+ },
447
+ "log": "",
448
+ }
@@ -91,19 +91,21 @@ __arg__out = main({", ".join(run_args)})
91
91
  logs = log_buffer.getvalue()
92
92
  result = exec_globals["__arg__out"]
93
93
 
94
- if issubclass(output_type, BaseModel) and not isinstance(result, output_type):
95
- try:
96
- result = output_type.model_validate(result)
97
- except ValidationError as e:
94
+ if output_type != Any:
95
+ if issubclass(output_type, BaseModel) and not isinstance(result, output_type):
96
+ try:
97
+ result = output_type.model_validate(result)
98
+ except ValidationError as e:
99
+ raise NodeException(
100
+ code=WorkflowErrorCode.INVALID_OUTPUTS,
101
+ message=re.sub(r"\s+For further information visit [^\s]+", "", str(e)),
102
+ ) from e
103
+
104
+ if not isinstance(result, output_type):
98
105
  raise NodeException(
99
106
  code=WorkflowErrorCode.INVALID_OUTPUTS,
100
- message=re.sub(r"\s+For further information visit [^\s]+", "", str(e)),
101
- ) from e
102
-
103
- if not isinstance(result, output_type):
104
- raise NodeException(
105
- code=WorkflowErrorCode.INVALID_OUTPUTS,
106
- message=f"Expected an output of type '{output_type.__name__}', but received '{result.__class__.__name__}'",
107
- )
107
+ message=f"Expected an output of type '{output_type.__name__}',"
108
+ f" but received '{result.__class__.__name__}'",
109
+ )
108
110
 
109
111
  return logs, result
@@ -5,6 +5,8 @@ from typing import Any, Iterator, List
5
5
 
6
6
  from vellum.client.core.api_error import ApiError
7
7
  from vellum.client.core.pydantic_utilities import UniversalBaseModel
8
+ from vellum.client.types.chat_message import ChatMessage
9
+ from vellum.client.types.chat_message_request import ChatMessageRequest
8
10
  from vellum.client.types.execute_prompt_event import ExecutePromptEvent
9
11
  from vellum.client.types.fulfilled_execute_prompt_event import FulfilledExecutePromptEvent
10
12
  from vellum.client.types.function_call import FunctionCall
@@ -12,6 +14,7 @@ from vellum.client.types.function_call_vellum_value import FunctionCallVellumVal
12
14
  from vellum.client.types.function_definition import FunctionDefinition
13
15
  from vellum.client.types.initiated_execute_prompt_event import InitiatedExecutePromptEvent
14
16
  from vellum.client.types.prompt_output import PromptOutput
17
+ from vellum.client.types.prompt_request_chat_history_input import PromptRequestChatHistoryInput
15
18
  from vellum.client.types.prompt_request_json_input import PromptRequestJsonInput
16
19
  from vellum.client.types.string_vellum_value import StringVellumValue
17
20
  from vellum.workflows.errors.types import WorkflowErrorCode
@@ -181,3 +184,49 @@ def test_inline_prompt_node__api_error__invalid_inputs_node_exception(
181
184
  # THEN the node raises the correct NodeException
182
185
  assert e.value.code == expected_code
183
186
  assert e.value.message == expected_message
187
+
188
+
189
+ def test_inline_prompt_node__chat_history_inputs(vellum_adhoc_prompt_client):
190
+ # GIVEN a prompt node with a chat history input
191
+ class MyNode(InlinePromptNode):
192
+ ml_model = "gpt-4o"
193
+ blocks = []
194
+ prompt_inputs = {
195
+ "chat_history": [ChatMessageRequest(role="USER", text="Hello, how are you?")],
196
+ }
197
+
198
+ # AND a known response from invoking an inline prompt
199
+ expected_outputs: List[PromptOutput] = [
200
+ StringVellumValue(value="Great!"),
201
+ ]
202
+
203
+ def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
204
+ execution_id = str(uuid4())
205
+ events: List[ExecutePromptEvent] = [
206
+ InitiatedExecutePromptEvent(execution_id=execution_id),
207
+ FulfilledExecutePromptEvent(
208
+ execution_id=execution_id,
209
+ outputs=expected_outputs,
210
+ ),
211
+ ]
212
+ yield from events
213
+
214
+ vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.side_effect = generate_prompt_events
215
+
216
+ # WHEN the node is run
217
+ events = list(MyNode().run())
218
+
219
+ # THEN the prompt is executed with the correct inputs
220
+ assert events[-1].value == "Great!"
221
+
222
+ # AND the prompt is executed with the correct inputs
223
+ mock_api = vellum_adhoc_prompt_client.adhoc_execute_prompt_stream
224
+ assert mock_api.call_count == 1
225
+ assert mock_api.call_args.kwargs["input_values"] == [
226
+ PromptRequestChatHistoryInput(
227
+ key="chat_history",
228
+ type="CHAT_HISTORY",
229
+ value=[ChatMessage(role="USER", text="Hello, how are you?")],
230
+ ),
231
+ ]
232
+ assert mock_api.call_args.kwargs["input_variables"][0].type == "CHAT_HISTORY"
@@ -0,0 +1,96 @@
1
+ import pytest
2
+ from uuid import uuid4
3
+ from typing import Any, Iterator, List
4
+
5
+ from vellum.client.types.chat_history_input_request import ChatHistoryInputRequest
6
+ from vellum.client.types.chat_message import ChatMessage
7
+ from vellum.client.types.chat_message_request import ChatMessageRequest
8
+ from vellum.client.types.execute_prompt_event import ExecutePromptEvent
9
+ from vellum.client.types.fulfilled_execute_prompt_event import FulfilledExecutePromptEvent
10
+ from vellum.client.types.initiated_execute_prompt_event import InitiatedExecutePromptEvent
11
+ from vellum.client.types.json_input_request import JsonInputRequest
12
+ from vellum.client.types.string_vellum_value import StringVellumValue
13
+ from vellum.workflows.nodes.displayable.prompt_deployment_node.node import PromptDeploymentNode
14
+
15
+
16
+ @pytest.mark.parametrize("ChatMessageClass", [ChatMessageRequest, ChatMessage])
17
+ def test_run_node__chat_history_input(vellum_client, ChatMessageClass):
18
+ """Confirm that we can successfully invoke a Prompt Deployment Node that uses Chat History Inputs"""
19
+
20
+ # GIVEN a Prompt Deployment Node
21
+ class ExamplePromptDeploymentNode(PromptDeploymentNode):
22
+ deployment = "example_prompt_deployment"
23
+ prompt_inputs = {
24
+ "chat_history": [ChatMessageClass(role="USER", text="Hello, how are you?")],
25
+ }
26
+
27
+ # AND we know what the Prompt Deployment will respond with
28
+ def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
29
+ execution_id = str(uuid4())
30
+ events: List[ExecutePromptEvent] = [
31
+ InitiatedExecutePromptEvent(execution_id=execution_id),
32
+ FulfilledExecutePromptEvent(
33
+ execution_id=execution_id,
34
+ outputs=[
35
+ StringVellumValue(value="Great!"),
36
+ ],
37
+ ),
38
+ ]
39
+ yield from events
40
+
41
+ vellum_client.execute_prompt_stream.side_effect = generate_prompt_events
42
+
43
+ # WHEN we run the node
44
+ node = ExamplePromptDeploymentNode()
45
+ events = list(node.run())
46
+
47
+ # THEN the node should have completed successfully
48
+ assert events[-1].value == "Great!"
49
+
50
+ # AND we should have invoked the Prompt Deployment with the expected inputs
51
+ call_kwargs = vellum_client.execute_prompt_stream.call_args.kwargs
52
+ assert call_kwargs["inputs"] == [
53
+ ChatHistoryInputRequest(
54
+ name="chat_history", value=[ChatMessageRequest(role="USER", text="Hello, how are you?")]
55
+ ),
56
+ ]
57
+
58
+
59
+ def test_run_node__any_array_input(vellum_client):
60
+ """Confirm that we can successfully invoke a Prompt Deployment Node that uses any array input"""
61
+
62
+ # GIVEN a Prompt Deployment Node
63
+ class ExamplePromptDeploymentNode(PromptDeploymentNode):
64
+ deployment = "example_prompt_deployment"
65
+ prompt_inputs = {
66
+ "fruits": ["apple", "banana", "cherry"],
67
+ }
68
+
69
+ # AND we know what the Prompt Deployment will respond with
70
+ def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
71
+ execution_id = str(uuid4())
72
+ events: List[ExecutePromptEvent] = [
73
+ InitiatedExecutePromptEvent(execution_id=execution_id),
74
+ FulfilledExecutePromptEvent(
75
+ execution_id=execution_id,
76
+ outputs=[
77
+ StringVellumValue(value="Great!"),
78
+ ],
79
+ ),
80
+ ]
81
+ yield from events
82
+
83
+ vellum_client.execute_prompt_stream.side_effect = generate_prompt_events
84
+
85
+ # WHEN we run the node
86
+ node = ExamplePromptDeploymentNode()
87
+ events = list(node.run())
88
+
89
+ # THEN the node should have completed successfully
90
+ assert events[-1].value == "Great!"
91
+
92
+ # AND we should have invoked the Prompt Deployment with the expected inputs
93
+ call_kwargs = vellum_client.execute_prompt_stream.call_args.kwargs
94
+ assert call_kwargs["inputs"] == [
95
+ JsonInputRequest(name="fruits", value=["apple", "banana", "cherry"]),
96
+ ]
@@ -1,3 +1,4 @@
1
+ import json
1
2
  from uuid import UUID
2
3
  from typing import Any, ClassVar, Dict, Generic, Iterator, List, Optional, Set, Union, cast
3
4
 
@@ -11,11 +12,13 @@ from vellum import (
11
12
  WorkflowRequestNumberInputRequest,
12
13
  WorkflowRequestStringInputRequest,
13
14
  )
15
+ from vellum.client.types.chat_message_request import ChatMessageRequest
14
16
  from vellum.core import RequestOptions
15
17
  from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT
16
18
  from vellum.workflows.context import get_parent_context
17
19
  from vellum.workflows.errors import WorkflowErrorCode
18
20
  from vellum.workflows.errors.types import workflow_event_error_to_workflow_error
21
+ from vellum.workflows.events.types import default_serializer
19
22
  from vellum.workflows.exceptions import NodeException
20
23
  from vellum.workflows.nodes.bases.base import BaseNode
21
24
  from vellum.workflows.outputs.base import BaseOutput
@@ -66,11 +69,22 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
66
69
  value=input_value,
67
70
  )
68
71
  )
69
- elif isinstance(input_value, list) and all(isinstance(message, ChatMessage) for message in input_value):
72
+ elif isinstance(input_value, list) and all(
73
+ isinstance(message, (ChatMessage, ChatMessageRequest)) for message in input_value
74
+ ):
75
+ chat_history = [
76
+ (
77
+ message
78
+ if isinstance(message, ChatMessageRequest)
79
+ else ChatMessageRequest.model_validate(message.model_dump())
80
+ )
81
+ for message in input_value
82
+ if isinstance(message, (ChatMessage, ChatMessageRequest))
83
+ ]
70
84
  compiled_inputs.append(
71
85
  WorkflowRequestChatHistoryInputRequest(
72
86
  name=input_name,
73
- value=cast(List[ChatMessage], input_value),
87
+ value=chat_history,
74
88
  )
75
89
  )
76
90
  elif isinstance(input_value, dict):
@@ -88,9 +102,19 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
88
102
  )
89
103
  )
90
104
  else:
91
- raise NodeException(
92
- message=f"Unrecognized input type for input '{input_name}'",
93
- code=WorkflowErrorCode.INVALID_INPUTS,
105
+ try:
106
+ input_value = default_serializer(input_value)
107
+ except json.JSONDecodeError as e:
108
+ raise NodeException(
109
+ message=f"Failed to serialize input '{input_name}' of type '{input_value.__class__}': {e}",
110
+ code=WorkflowErrorCode.INVALID_INPUTS,
111
+ )
112
+
113
+ compiled_inputs.append(
114
+ WorkflowRequestJsonInputRequest(
115
+ name=input_name,
116
+ value=input_value,
117
+ )
94
118
  )
95
119
 
96
120
  return compiled_inputs
@@ -0,0 +1,131 @@
1
+ import pytest
2
+ from datetime import datetime
3
+ from uuid import uuid4
4
+ from typing import Any, Iterator, List
5
+
6
+ from vellum.client.types.chat_message import ChatMessage
7
+ from vellum.client.types.chat_message_request import ChatMessageRequest
8
+ from vellum.client.types.workflow_execution_workflow_result_event import WorkflowExecutionWorkflowResultEvent
9
+ from vellum.client.types.workflow_output_string import WorkflowOutputString
10
+ from vellum.client.types.workflow_request_chat_history_input_request import WorkflowRequestChatHistoryInputRequest
11
+ from vellum.client.types.workflow_request_json_input_request import WorkflowRequestJsonInputRequest
12
+ from vellum.client.types.workflow_result_event import WorkflowResultEvent
13
+ from vellum.client.types.workflow_stream_event import WorkflowStreamEvent
14
+ from vellum.workflows.nodes.displayable.subworkflow_deployment_node.node import SubworkflowDeploymentNode
15
+
16
+
17
+ @pytest.mark.parametrize("ChatMessageClass", [ChatMessageRequest, ChatMessage])
18
+ def test_run_workflow__chat_history_input(vellum_client, ChatMessageClass):
19
+ """Confirm that we can successfully invoke a Subworkflow Deployment Node that uses Chat History Inputs"""
20
+
21
+ # GIVEN a Subworkflow Deployment Node
22
+ class ExampleSubworkflowDeploymentNode(SubworkflowDeploymentNode):
23
+ deployment = "example_subworkflow_deployment"
24
+ subworkflow_inputs = {
25
+ "chat_history": [ChatMessageClass(role="USER", text="Hello, how are you?")],
26
+ }
27
+
28
+ # AND we know what the Subworkflow Deployment will respond with
29
+ def generate_subworkflow_events(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
30
+ execution_id = str(uuid4())
31
+ expected_events: List[WorkflowStreamEvent] = [
32
+ WorkflowExecutionWorkflowResultEvent(
33
+ execution_id=execution_id,
34
+ data=WorkflowResultEvent(
35
+ id=str(uuid4()),
36
+ state="INITIATED",
37
+ ts=datetime.now(),
38
+ ),
39
+ ),
40
+ WorkflowExecutionWorkflowResultEvent(
41
+ execution_id=execution_id,
42
+ data=WorkflowResultEvent(
43
+ id=str(uuid4()),
44
+ state="FULFILLED",
45
+ ts=datetime.now(),
46
+ outputs=[
47
+ WorkflowOutputString(
48
+ id=str(uuid4()),
49
+ name="greeting",
50
+ value="Great!",
51
+ )
52
+ ],
53
+ ),
54
+ ),
55
+ ]
56
+ yield from expected_events
57
+
58
+ vellum_client.execute_workflow_stream.side_effect = generate_subworkflow_events
59
+
60
+ # WHEN we run the node
61
+ node = ExampleSubworkflowDeploymentNode()
62
+ events = list(node.run())
63
+
64
+ # THEN the node should have completed successfully
65
+ assert events[-1].name == "greeting"
66
+ assert events[-1].value == "Great!"
67
+
68
+ # AND we should have invoked the Subworkflow Deployment with the expected inputs
69
+ call_kwargs = vellum_client.execute_workflow_stream.call_args.kwargs
70
+ assert call_kwargs["inputs"] == [
71
+ WorkflowRequestChatHistoryInputRequest(
72
+ name="chat_history", value=[ChatMessageRequest(role="USER", text="Hello, how are you?")]
73
+ ),
74
+ ]
75
+
76
+
77
+ def test_run_workflow__any_array(vellum_client):
78
+ """Confirm that we can successfully invoke a Subworkflow Deployment Node that uses any array input"""
79
+
80
+ # GIVEN a Subworkflow Deployment Node
81
+ class ExampleSubworkflowDeploymentNode(SubworkflowDeploymentNode):
82
+ deployment = "example_subworkflow_deployment"
83
+ subworkflow_inputs = {
84
+ "fruits": ["apple", "banana", "cherry"],
85
+ }
86
+
87
+ # AND we know what the Subworkflow Deployment will respond with
88
+ def generate_subworkflow_events(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
89
+ execution_id = str(uuid4())
90
+ expected_events: List[WorkflowStreamEvent] = [
91
+ WorkflowExecutionWorkflowResultEvent(
92
+ execution_id=execution_id,
93
+ data=WorkflowResultEvent(
94
+ id=str(uuid4()),
95
+ state="INITIATED",
96
+ ts=datetime.now(),
97
+ ),
98
+ ),
99
+ WorkflowExecutionWorkflowResultEvent(
100
+ execution_id=execution_id,
101
+ data=WorkflowResultEvent(
102
+ id=str(uuid4()),
103
+ state="FULFILLED",
104
+ ts=datetime.now(),
105
+ outputs=[
106
+ WorkflowOutputString(
107
+ id=str(uuid4()),
108
+ name="greeting",
109
+ value="Great!",
110
+ )
111
+ ],
112
+ ),
113
+ ),
114
+ ]
115
+ yield from expected_events
116
+
117
+ vellum_client.execute_workflow_stream.side_effect = generate_subworkflow_events
118
+
119
+ # WHEN we run the node
120
+ node = ExampleSubworkflowDeploymentNode()
121
+ events = list(node.run())
122
+
123
+ # THEN the node should have completed successfully
124
+ assert events[-1].name == "greeting"
125
+ assert events[-1].value == "Great!"
126
+
127
+ # AND we should have invoked the Subworkflow Deployment with the expected inputs
128
+ call_kwargs = vellum_client.execute_workflow_stream.call_args.kwargs
129
+ assert call_kwargs["inputs"] == [
130
+ WorkflowRequestJsonInputRequest(name="fruits", value=["apple", "banana", "cherry"]),
131
+ ]
@@ -0,0 +1,17 @@
1
+ from typing import Sequence, Union
2
+
3
+ from pydantic import ConfigDict
4
+
5
+ from vellum.client.core.pydantic_utilities import UniversalBaseModel
6
+ from vellum.workflows.descriptors.base import BaseDescriptor
7
+ from vellum.workflows.outputs.base import BaseOutputs
8
+
9
+
10
+ class MockNodeExecution(UniversalBaseModel):
11
+ when_condition: BaseDescriptor
12
+ then_outputs: BaseOutputs
13
+
14
+ model_config = ConfigDict(arbitrary_types_allowed=True)
15
+
16
+
17
+ MockNodeExecutionArg = Sequence[Union[BaseOutputs, MockNodeExecution]]
@@ -4,21 +4,7 @@ import logging
4
4
  from queue import Empty, Queue
5
5
  from threading import Event as ThreadingEvent, Thread
6
6
  from uuid import UUID
7
- from typing import (
8
- TYPE_CHECKING,
9
- Any,
10
- Dict,
11
- Generic,
12
- Iterable,
13
- Iterator,
14
- List,
15
- Optional,
16
- Sequence,
17
- Set,
18
- Tuple,
19
- Type,
20
- Union,
21
- )
7
+ from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Tuple, Type, Union
22
8
 
23
9
  from vellum.workflows.constants import UNDEF
24
10
  from vellum.workflows.context import execution_context, get_parent_context
@@ -59,12 +45,12 @@ from vellum.workflows.events.workflow import (
59
45
  from vellum.workflows.exceptions import NodeException
60
46
  from vellum.workflows.nodes.bases import BaseNode
61
47
  from vellum.workflows.nodes.bases.base import NodeRunResponse
48
+ from vellum.workflows.nodes.mocks import MockNodeExecutionArg
62
49
  from vellum.workflows.outputs import BaseOutputs
63
50
  from vellum.workflows.outputs.base import BaseOutput
64
51
  from vellum.workflows.ports.port import Port
65
52
  from vellum.workflows.references import ExternalInputReference, OutputReference
66
53
  from vellum.workflows.state.base import BaseState
67
- from vellum.workflows.types.cycle_map import CycleMap
68
54
  from vellum.workflows.types.generics import InputsType, OutputsType, StateType
69
55
 
70
56
  if TYPE_CHECKING:
@@ -88,7 +74,7 @@ class WorkflowRunner(Generic[StateType]):
88
74
  entrypoint_nodes: Optional[RunFromNodeArg] = None,
89
75
  external_inputs: Optional[ExternalInputsArg] = None,
90
76
  cancel_signal: Optional[ThreadingEvent] = None,
91
- node_output_mocks: Optional[List[BaseOutputs]] = None,
77
+ node_output_mocks: Optional[MockNodeExecutionArg] = None,
92
78
  parent_context: Optional[ParentContext] = None,
93
79
  max_concurrency: Optional[int] = None,
94
80
  ):
@@ -144,7 +130,6 @@ class WorkflowRunner(Generic[StateType]):
144
130
 
145
131
  self._dependencies: Dict[Type[BaseNode], Set[Type[BaseNode]]] = defaultdict(set)
146
132
  self._state_forks: Set[StateType] = {self._initial_state}
147
- self._mocks_by_node_outputs_class = CycleMap(items=node_output_mocks or [], key_by=lambda mock: mock.__class__)
148
133
 
149
134
  self._active_nodes_by_execution_id: Dict[UUID, BaseNode[StateType]] = {}
150
135
  self._cancel_signal = cancel_signal
@@ -156,6 +141,7 @@ class WorkflowRunner(Generic[StateType]):
156
141
  lambda s: self._snapshot_state(s),
157
142
  )
158
143
  self.workflow.context._register_event_queue(self._workflow_event_inner_queue)
144
+ self.workflow.context._register_node_output_mocks(node_output_mocks or [])
159
145
 
160
146
  def _snapshot_state(self, state: StateType) -> StateType:
161
147
  self._workflow_event_inner_queue.put(
@@ -201,11 +187,18 @@ class WorkflowRunner(Generic[StateType]):
201
187
  parent=parent_context,
202
188
  )
203
189
  node_run_response: NodeRunResponse
204
- if node.Outputs not in self._mocks_by_node_outputs_class:
190
+ was_mocked = False
191
+ mock_candidates = self.workflow.context.node_output_mocks_map.get(node.Outputs) or []
192
+ for mock_candidate in mock_candidates:
193
+ if mock_candidate.when_condition.resolve(node.state):
194
+ node_run_response = mock_candidate.then_outputs
195
+ was_mocked = True
196
+ break
197
+
198
+ if not was_mocked:
205
199
  with execution_context(parent_context=updated_parent_context):
206
200
  node_run_response = node.run()
207
- else:
208
- node_run_response = self._mocks_by_node_outputs_class[node.Outputs]
201
+
209
202
  ports = node.Ports()
210
203
  if not isinstance(node_run_response, (BaseOutputs, Iterator)):
211
204
  raise NodeException(
@@ -519,19 +512,6 @@ class WorkflowRunner(Generic[StateType]):
519
512
  )
520
513
 
521
514
  def _stream(self) -> None:
522
- # TODO: We should likely handle this during initialization
523
- # https://app.shortcut.com/vellum/story/4327
524
- if not self._entrypoints:
525
- self._workflow_event_outer_queue.put(
526
- self._reject_workflow_event(
527
- WorkflowError(
528
- message="No entrypoints defined",
529
- code=WorkflowErrorCode.INVALID_WORKFLOW,
530
- )
531
- )
532
- )
533
- return
534
-
535
515
  for edge in self.workflow.get_edges():
536
516
  self._dependencies[edge.to_node].add(edge.from_port.node_class)
537
517