vellum-ai 0.14.11__py3-none-any.whl → 0.14.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/workflows/descriptors/base.py +3 -0
  3. vellum/workflows/events/workflow.py +23 -0
  4. vellum/workflows/inputs/base.py +26 -18
  5. vellum/workflows/inputs/tests/test_inputs.py +1 -1
  6. vellum/workflows/nodes/bases/base.py +7 -0
  7. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +7 -0
  8. vellum/workflows/nodes/core/inline_subworkflow_node/tests/test_node.py +32 -0
  9. vellum/workflows/nodes/core/map_node/node.py +28 -7
  10. vellum/workflows/nodes/core/map_node/tests/test_node.py +31 -0
  11. vellum/workflows/nodes/core/try_node/node.py +7 -0
  12. vellum/workflows/nodes/core/try_node/tests/test_node.py +32 -0
  13. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +5 -4
  14. vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/test_node.py +111 -0
  15. vellum/workflows/nodes/mocks.py +229 -2
  16. vellum/workflows/nodes/tests/__init__.py +0 -0
  17. vellum/workflows/nodes/tests/test_mocks.py +207 -0
  18. vellum/workflows/outputs/base.py +1 -1
  19. {vellum_ai-0.14.11.dist-info → vellum_ai-0.14.13.dist-info}/METADATA +2 -2
  20. {vellum_ai-0.14.11.dist-info → vellum_ai-0.14.13.dist-info}/RECORD +28 -26
  21. vellum_ee/workflows/display/nodes/base_node_display.py +20 -4
  22. vellum_ee/workflows/display/nodes/get_node_display_class.py +9 -0
  23. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +24 -1
  24. vellum_ee/workflows/display/workflows/base_workflow_display.py +2 -2
  25. vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +20 -0
  26. {vellum_ai-0.14.11.dist-info → vellum_ai-0.14.13.dist-info}/LICENSE +0 -0
  27. {vellum_ai-0.14.11.dist-info → vellum_ai-0.14.13.dist-info}/WHEEL +0 -0
  28. {vellum_ai-0.14.11.dist-info → vellum_ai-0.14.13.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
18
18
  headers: typing.Dict[str, str] = {
19
19
  "X-Fern-Language": "Python",
20
20
  "X-Fern-SDK-Name": "vellum-ai",
21
- "X-Fern-SDK-Version": "0.14.11",
21
+ "X-Fern-SDK-Version": "0.14.13",
22
22
  }
23
23
  headers["X_API_KEY"] = self.api_key
24
24
  return headers
@@ -71,6 +71,9 @@ class BaseDescriptor(Generic[_T]):
71
71
  def __hash__(self) -> int:
72
72
  return hash(self._name)
73
73
 
74
+ def __repr__(self) -> str:
75
+ return self._name
76
+
74
77
  @overload
75
78
  def __get__(self, instance: "BaseNode", owner: Type["BaseNode"]) -> _T: ...
76
79
 
@@ -1,5 +1,6 @@
1
1
  from uuid import UUID
2
2
  from typing import TYPE_CHECKING, Any, Dict, Generator, Generic, Iterable, Literal, Optional, Type, Union
3
+ from typing_extensions import TypeGuard
3
4
 
4
5
  from pydantic import field_serializer
5
6
 
@@ -182,3 +183,25 @@ WorkflowEvent = Union[
182
183
  ]
183
184
 
184
185
  WorkflowEventStream = Generator[WorkflowEvent, None, None]
186
+
187
+ WorkflowExecutionEvent = Union[
188
+ WorkflowExecutionInitiatedEvent,
189
+ WorkflowExecutionStreamingEvent,
190
+ WorkflowExecutionRejectedEvent,
191
+ WorkflowExecutionPausedEvent,
192
+ WorkflowExecutionResumedEvent,
193
+ WorkflowExecutionFulfilledEvent,
194
+ WorkflowExecutionSnapshottedEvent,
195
+ ]
196
+
197
+
198
+ def is_workflow_event(event: WorkflowEvent) -> TypeGuard[WorkflowExecutionEvent]:
199
+ return (
200
+ event.name == "workflow.execution.initiated"
201
+ or event.name == "workflow.execution.fulfilled"
202
+ or event.name == "workflow.execution.streaming"
203
+ or event.name == "workflow.execution.snapshotted"
204
+ or event.name == "workflow.execution.paused"
205
+ or event.name == "workflow.execution.resumed"
206
+ or event.name == "workflow.execution.rejected"
207
+ )
@@ -42,38 +42,46 @@ class BaseInputs(metaclass=_BaseInputsMeta):
42
42
  __parent_class__: Type = type(None)
43
43
 
44
44
  def __init__(self, **kwargs: Any) -> None:
45
+ """
46
+ Initialize BaseInputs with provided keyword arguments.
47
+
48
+ Validation logic:
49
+ 1. Ensures all required fields (non-Optional types) either:
50
+ - Have a value provided in kwargs, or
51
+ - Have a default value defined in the class
52
+ 2. Validates that no None values are provided for required fields
53
+ 3. Sets all provided values as attributes on the instance
54
+
55
+ Args:
56
+ **kwargs: Keyword arguments corresponding to the class's type annotations
57
+
58
+ Raises:
59
+ WorkflowInitializationException: If a required field is missing or None
60
+ """
45
61
  for name, field_type in self.__class__.__annotations__.items():
46
- if name not in kwargs and name not in vars(self.__class__):
62
+ # Get the value (either from kwargs or class default)
63
+ value = kwargs.get(name)
64
+ has_default = name in vars(self.__class__)
65
+
66
+ if value is None and not has_default:
67
+ # Check if field_type allows None
47
68
  origin = get_origin(field_type)
48
69
  args = get_args(field_type)
49
70
  if not (origin is Union and type(None) in args):
50
71
  raise WorkflowInitializationException(
51
- message="Required input variables should have defined value",
72
+ message=f"Required input variables {name} should have defined value",
52
73
  code=WorkflowErrorCode.INVALID_INPUTS,
53
74
  )
54
75
 
55
- for name, value in kwargs.items():
56
- field_type = self.__class__.__annotations__.get(name)
57
- if field_type:
58
- self._validate_input(value, field_type)
59
- setattr(self, name, value)
76
+ # If value provided in kwargs, set it on the instance
77
+ if name in kwargs:
78
+ setattr(self, name, value)
60
79
 
61
80
  def __iter__(self) -> Iterator[Tuple[InputReference, Any]]:
62
81
  for input_descriptor in self.__class__:
63
82
  if hasattr(self, input_descriptor.name):
64
83
  yield (input_descriptor, getattr(self, input_descriptor.name))
65
84
 
66
- def _validate_input(self, value: Any, field_type: Any) -> None:
67
- if value is None:
68
- # Check if field_type is Optional
69
- origin = get_origin(field_type)
70
- args = get_args(field_type)
71
- if not (origin is Union and type(None) in args):
72
- raise WorkflowInitializationException(
73
- message="Required input variables should have defined value",
74
- code=WorkflowErrorCode.INVALID_INPUTS,
75
- )
76
-
77
85
  @classmethod
78
86
  def __get_pydantic_core_schema__(
79
87
  cls, source_type: Type[Any], handler: GetCoreSchemaHandler
@@ -34,7 +34,7 @@ def test_base_inputs_empty_value():
34
34
 
35
35
  # THEN it should raise a NodeException with the correct error message and code
36
36
  assert exc_info.value.code == WorkflowErrorCode.INVALID_INPUTS
37
- assert "Required input variables should have defined value" in str(exc_info.value)
37
+ assert "Required input variables required_string should have defined value" == str(exc_info.value)
38
38
 
39
39
 
40
40
  def test_base_inputs_with_default():
@@ -19,6 +19,7 @@ from vellum.workflows.ports.port import Port
19
19
  from vellum.workflows.references import ExternalInputReference
20
20
  from vellum.workflows.references.execution_count import ExecutionCountReference
21
21
  from vellum.workflows.references.node import NodeReference
22
+ from vellum.workflows.references.output import OutputReference
22
23
  from vellum.workflows.state.base import BaseState
23
24
  from vellum.workflows.state.context import WorkflowContext
24
25
  from vellum.workflows.types.core import MergeBehavior
@@ -118,6 +119,11 @@ class BaseNodeMeta(type):
118
119
  node_class.Trigger.node_class = node_class
119
120
  node_class.ExternalInputs.__parent_class__ = node_class
120
121
  node_class.__id__ = uuid4_from_hash(node_class.__qualname__)
122
+ node_class.__output_ids__ = {
123
+ ref.name: uuid4_from_hash(f"{node_class.__id__}|{ref.name}")
124
+ for ref in node_class.Outputs
125
+ if isinstance(ref, OutputReference)
126
+ }
121
127
  return node_class
122
128
 
123
129
  @property
@@ -236,6 +242,7 @@ NodeRunResponse = Union[BaseOutputs, Iterator[BaseOutput]]
236
242
 
237
243
  class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
238
244
  __id__: UUID = uuid4_from_hash(__qualname__)
245
+ __output_ids__: Dict[str, UUID] = {}
239
246
  state: StateType
240
247
  _context: WorkflowContext
241
248
  _inputs: MappingProxyType[NodeReference, Any]
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Iterator, Option
3
3
  from vellum.workflows.constants import undefined
4
4
  from vellum.workflows.context import execution_context, get_parent_context
5
5
  from vellum.workflows.errors.types import WorkflowErrorCode
6
+ from vellum.workflows.events.workflow import is_workflow_event
6
7
  from vellum.workflows.exceptions import NodeException
7
8
  from vellum.workflows.inputs.base import BaseInputs
8
9
  from vellum.workflows.nodes.bases.base import BaseNode, BaseNodeMeta
@@ -86,6 +87,12 @@ class InlineSubworkflowNode(
86
87
 
87
88
  for event in subworkflow_stream:
88
89
  self._context._emit_subworkflow_event(event)
90
+
91
+ if not is_workflow_event(event):
92
+ continue
93
+ if event.workflow_definition != self.subworkflow:
94
+ continue
95
+
89
96
  if event.name == "workflow.execution.streaming":
90
97
  if event.output.is_fulfilled:
91
98
  fulfilled_output_names.add(event.output.name)
@@ -3,6 +3,7 @@ import pytest
3
3
  from vellum.workflows.inputs.base import BaseInputs
4
4
  from vellum.workflows.nodes.bases.base import BaseNode
5
5
  from vellum.workflows.nodes.core.inline_subworkflow_node.node import InlineSubworkflowNode
6
+ from vellum.workflows.nodes.core.try_node.node import TryNode
6
7
  from vellum.workflows.outputs.base import BaseOutput
7
8
  from vellum.workflows.state.base import BaseState
8
9
  from vellum.workflows.workflows.base import BaseWorkflow
@@ -55,3 +56,34 @@ def test_inline_subworkflow_node__support_inputs_as_attributes():
55
56
  assert events == [
56
57
  BaseOutput(name="out", value="bar"),
57
58
  ]
59
+
60
+
61
+ def test_inline_subworkflow_node__nested_try():
62
+ """
63
+ Ensure that the nested try node doesn't affect the subworkflow node's outputs
64
+ """
65
+
66
+ # GIVEN a nested try node
67
+ @TryNode.wrap()
68
+ class InnerNode(BaseNode):
69
+ class Outputs:
70
+ foo = "hello"
71
+
72
+ # AND a subworkflow
73
+ class Subworkflow(BaseWorkflow):
74
+ graph = InnerNode
75
+
76
+ class Outputs(BaseWorkflow.Outputs):
77
+ bar = InnerNode.Outputs.foo
78
+
79
+ # AND an outer try node referencing that subworkflow
80
+ class OuterNode(InlineSubworkflowNode):
81
+ subworkflow = Subworkflow
82
+
83
+ # WHEN we run the try node
84
+ stream = OuterNode().run()
85
+ events = list(stream)
86
+
87
+ # THEN we only have the outer node's outputs
88
+ valid_events = [e for e in events if e.name == "bar"]
89
+ assert len(valid_events) == len(events)
@@ -1,4 +1,5 @@
1
1
  from collections import defaultdict
2
+ import logging
2
3
  from queue import Empty, Queue
3
4
  from threading import Thread
4
5
  from typing import (
@@ -19,6 +20,7 @@ from typing import (
19
20
  from vellum.workflows.context import ExecutionContext, execution_context, get_execution_context
20
21
  from vellum.workflows.descriptors.base import BaseDescriptor
21
22
  from vellum.workflows.errors.types import WorkflowErrorCode
23
+ from vellum.workflows.events.workflow import is_workflow_event
22
24
  from vellum.workflows.exceptions import NodeException
23
25
  from vellum.workflows.inputs.base import BaseInputs
24
26
  from vellum.workflows.nodes.bases.base_adornment_node import BaseAdornmentNode
@@ -33,6 +35,8 @@ from vellum.workflows.workflows.event_filters import all_workflow_event_filter
33
35
  if TYPE_CHECKING:
34
36
  from vellum.workflows.events.workflow import WorkflowEvent
35
37
 
38
+ logger = logging.getLogger(__name__)
39
+
36
40
  MapNodeItemType = TypeVar("MapNodeItemType")
37
41
 
38
42
 
@@ -104,19 +108,36 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
104
108
  subworkflow_event = map_node_event[1]
105
109
  self._context._emit_subworkflow_event(subworkflow_event)
106
110
 
111
+ if not is_workflow_event(subworkflow_event):
112
+ continue
113
+
114
+ if subworkflow_event.workflow_definition != self.subworkflow:
115
+ continue
116
+
107
117
  if subworkflow_event.name == "workflow.execution.initiated":
108
118
  for output_name in mapped_items.keys():
109
119
  yield BaseOutput(name=output_name, delta=(None, index, "INITIATED"))
110
120
 
111
121
  elif subworkflow_event.name == "workflow.execution.fulfilled":
112
- workflow_output_vars = vars(subworkflow_event.outputs)
113
-
114
- for output_name in workflow_output_vars:
115
- output_mapped_items = mapped_items[output_name]
116
- output_mapped_items[index] = workflow_output_vars[output_name]
122
+ for output_reference, output_value in subworkflow_event.outputs:
123
+ if not isinstance(output_reference, OutputReference):
124
+ logger.error(
125
+ "Invalid key to map node's subworkflow event outputs",
126
+ extra={"output_reference_type": type(output_reference)},
127
+ )
128
+ continue
129
+
130
+ output_mapped_items = mapped_items[output_reference.name]
131
+ if index < 0 or index >= len(output_mapped_items):
132
+ logger.error(
133
+ "Invalid map node index", extra={"index": index, "output_name": output_reference.name}
134
+ )
135
+ continue
136
+
137
+ output_mapped_items[index] = output_value
117
138
  yield BaseOutput(
118
- name=output_name,
119
- delta=(output_mapped_items[index], index, "FULFILLED"),
139
+ name=output_reference.name,
140
+ delta=(output_value, index, "FULFILLED"),
120
141
  )
121
142
 
122
143
  fulfilled_iterations[index] = True
@@ -3,8 +3,10 @@ import time
3
3
  from vellum.workflows.inputs.base import BaseInputs
4
4
  from vellum.workflows.nodes.bases import BaseNode
5
5
  from vellum.workflows.nodes.core.map_node.node import MapNode
6
+ from vellum.workflows.nodes.core.try_node.node import TryNode
6
7
  from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
7
8
  from vellum.workflows.state.base import BaseState, StateMeta
9
+ from vellum.workflows.workflows.base import BaseWorkflow
8
10
 
9
11
 
10
12
  def test_map_node__use_parent_inputs_and_state():
@@ -85,3 +87,32 @@ def test_map_node__empty_list():
85
87
  # THEN the node should return an empty output
86
88
  fulfilled_output = outputs[-1]
87
89
  assert fulfilled_output == BaseOutput(name="value", value=[])
90
+
91
+
92
+ def test_map_node__inner_try():
93
+ # GIVEN a try wrapped node
94
+ @TryNode.wrap()
95
+ class InnerNode(BaseNode):
96
+ class Outputs(BaseNode.Outputs):
97
+ foo: str
98
+
99
+ # AND a workflow using that node
100
+ class SimpleMapNodeWorkflow(BaseWorkflow[MapNode.SubworkflowInputs, BaseState]):
101
+ graph = InnerNode
102
+
103
+ class Outputs(BaseWorkflow.Outputs):
104
+ final_output = InnerNode.Outputs.foo
105
+
106
+ # AND a map node referencing that workflow
107
+ class SimpleMapNode(MapNode):
108
+ items = ["hello", "world"]
109
+ subworkflow = SimpleMapNodeWorkflow
110
+ max_concurrency = 4
111
+
112
+ # WHEN we run the workflow
113
+ stream = SimpleMapNode().run()
114
+ outputs = list(stream)
115
+
116
+ # THEN the workflow should succeed
117
+ assert outputs[-1].name == "final_output"
118
+ assert len(outputs[-1].value) == 2
@@ -2,6 +2,7 @@ from typing import Callable, Generic, Iterator, Optional, Set, Type
2
2
 
3
3
  from vellum.workflows.context import execution_context, get_parent_context
4
4
  from vellum.workflows.errors.types import WorkflowError, WorkflowErrorCode
5
+ from vellum.workflows.events.workflow import is_workflow_event
5
6
  from vellum.workflows.exceptions import NodeException
6
7
  from vellum.workflows.nodes.bases import BaseNode
7
8
  from vellum.workflows.nodes.bases.base_adornment_node import BaseAdornmentNode
@@ -47,6 +48,12 @@ class TryNode(BaseAdornmentNode[StateType], Generic[StateType]):
47
48
  if exception:
48
49
  continue
49
50
 
51
+ if not is_workflow_event(event):
52
+ continue
53
+
54
+ if event.workflow_definition != self.subworkflow:
55
+ continue
56
+
50
57
  if event.name == "workflow.execution.streaming":
51
58
  if event.output.is_fulfilled:
52
59
  fulfilled_output_names.add(event.output.name)
@@ -10,6 +10,7 @@ from vellum.workflows.outputs import BaseOutputs
10
10
  from vellum.workflows.outputs.base import BaseOutput
11
11
  from vellum.workflows.state.base import BaseState, StateMeta
12
12
  from vellum.workflows.state.context import WorkflowContext
13
+ from vellum.workflows.workflows.base import BaseWorkflow
13
14
 
14
15
 
15
16
  def test_try_node__on_error_code__successfully_caught():
@@ -126,3 +127,34 @@ def test_try_node__resolved_inputs():
126
127
  foo = State.counter
127
128
 
128
129
  assert MyNode.foo.types == (float,)
130
+
131
+
132
+ def test_try_node__nested_try():
133
+ """
134
+ Ensure that the nested try node doesn't affect the outer try node's outputs
135
+ """
136
+
137
+ # GIVEN a nested try node
138
+ @TryNode.wrap()
139
+ class InnerNode(BaseNode):
140
+ class Outputs:
141
+ foo = "hello"
142
+
143
+ # AND a subworkflow
144
+ class Subworkflow(BaseWorkflow):
145
+ graph = InnerNode
146
+
147
+ class Outputs(BaseWorkflow.Outputs):
148
+ bar = InnerNode.Outputs.foo
149
+
150
+ # AND an outer try node referencing that subworkflow
151
+ class OuterNode(TryNode):
152
+ subworkflow = Subworkflow
153
+
154
+ # WHEN we run the try node
155
+ stream = OuterNode().run()
156
+ events = list(stream)
157
+
158
+ # THEN we only have the outer node's outputs
159
+ valid_events = [e for e in events if e.name == "bar"]
160
+ assert len(valid_events) == len(events)
@@ -70,8 +70,10 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
70
70
  value=input_value,
71
71
  )
72
72
  )
73
- elif isinstance(input_value, list) and all(
74
- isinstance(message, (ChatMessage, ChatMessageRequest)) for message in input_value
73
+ elif (
74
+ isinstance(input_value, list)
75
+ and len(input_value) > 0
76
+ and all(isinstance(message, (ChatMessage, ChatMessageRequest)) for message in input_value)
75
77
  ):
76
78
  chat_history = [
77
79
  (
@@ -95,7 +97,7 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
95
97
  value=cast(Dict[str, Any], input_value),
96
98
  )
97
99
  )
98
- elif isinstance(input_value, float):
100
+ elif isinstance(input_value, (int, float)):
99
101
  compiled_inputs.append(
100
102
  WorkflowRequestNumberInputRequest(
101
103
  name=input_name,
@@ -110,7 +112,6 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
110
112
  message=f"Failed to serialize input '{input_name}' of type '{input_value.__class__}': {e}",
111
113
  code=WorkflowErrorCode.INVALID_INPUTS,
112
114
  )
113
-
114
115
  compiled_inputs.append(
115
116
  WorkflowRequestJsonInputRequest(
116
117
  name=input_name,
@@ -10,6 +10,7 @@ from vellum.client.types.workflow_execution_workflow_result_event import Workflo
10
10
  from vellum.client.types.workflow_output_string import WorkflowOutputString
11
11
  from vellum.client.types.workflow_request_chat_history_input_request import WorkflowRequestChatHistoryInputRequest
12
12
  from vellum.client.types.workflow_request_json_input_request import WorkflowRequestJsonInputRequest
13
+ from vellum.client.types.workflow_request_number_input_request import WorkflowRequestNumberInputRequest
13
14
  from vellum.client.types.workflow_result_event import WorkflowResultEvent
14
15
  from vellum.client.types.workflow_stream_event import WorkflowStreamEvent
15
16
  from vellum.workflows.errors import WorkflowErrorCode
@@ -134,6 +135,116 @@ def test_run_workflow__any_array(vellum_client):
134
135
  ]
135
136
 
136
137
 
138
+ def test_run_workflow__empty_array(vellum_client):
139
+ # GIVEN a Subworkflow Deployment Node
140
+ class ExampleSubworkflowDeploymentNode(SubworkflowDeploymentNode):
141
+ deployment = "example_subworkflow_deployment"
142
+ subworkflow_inputs = {
143
+ "fruits": [],
144
+ }
145
+
146
+ # AND we know what the Subworkflow Deployment will respond with
147
+ def generate_subworkflow_events(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
148
+ execution_id = str(uuid4())
149
+ expected_events: List[WorkflowStreamEvent] = [
150
+ WorkflowExecutionWorkflowResultEvent(
151
+ execution_id=execution_id,
152
+ data=WorkflowResultEvent(
153
+ id=str(uuid4()),
154
+ state="INITIATED",
155
+ ts=datetime.now(),
156
+ ),
157
+ ),
158
+ WorkflowExecutionWorkflowResultEvent(
159
+ execution_id=execution_id,
160
+ data=WorkflowResultEvent(
161
+ id=str(uuid4()),
162
+ state="FULFILLED",
163
+ ts=datetime.now(),
164
+ outputs=[
165
+ WorkflowOutputString(
166
+ id=str(uuid4()),
167
+ name="greeting",
168
+ value="Great!",
169
+ )
170
+ ],
171
+ ),
172
+ ),
173
+ ]
174
+ yield from expected_events
175
+
176
+ vellum_client.execute_workflow_stream.side_effect = generate_subworkflow_events
177
+
178
+ # WHEN we run the node
179
+ node = ExampleSubworkflowDeploymentNode()
180
+ events = list(node.run())
181
+
182
+ # THEN the node should have completed successfully
183
+ assert events[-1].name == "greeting"
184
+ assert events[-1].value == "Great!"
185
+
186
+ # AND we should have invoked the Subworkflow Deployment with the expected inputs
187
+ call_kwargs = vellum_client.execute_workflow_stream.call_args.kwargs
188
+ assert call_kwargs["inputs"] == [
189
+ WorkflowRequestJsonInputRequest(name="fruits", value=[]),
190
+ ]
191
+
192
+
193
+ def test_run_workflow__int_input(vellum_client):
194
+ # GIVEN a Subworkflow Deployment Node
195
+ class ExampleSubworkflowDeploymentNode(SubworkflowDeploymentNode):
196
+ deployment = "example_subworkflow_deployment"
197
+ subworkflow_inputs = {
198
+ "number": 42,
199
+ }
200
+
201
+ # AND we know what the Subworkflow Deployment will respond with
202
+ def generate_subworkflow_events(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
203
+ execution_id = str(uuid4())
204
+ expected_events: List[WorkflowStreamEvent] = [
205
+ WorkflowExecutionWorkflowResultEvent(
206
+ execution_id=execution_id,
207
+ data=WorkflowResultEvent(
208
+ id=str(uuid4()),
209
+ state="INITIATED",
210
+ ts=datetime.now(),
211
+ ),
212
+ ),
213
+ WorkflowExecutionWorkflowResultEvent(
214
+ execution_id=execution_id,
215
+ data=WorkflowResultEvent(
216
+ id=str(uuid4()),
217
+ state="FULFILLED",
218
+ ts=datetime.now(),
219
+ outputs=[
220
+ WorkflowOutputString(
221
+ id=str(uuid4()),
222
+ name="greeting",
223
+ value="Great!",
224
+ )
225
+ ],
226
+ ),
227
+ ),
228
+ ]
229
+ yield from expected_events
230
+
231
+ vellum_client.execute_workflow_stream.side_effect = generate_subworkflow_events
232
+
233
+ # WHEN we run the node
234
+ node = ExampleSubworkflowDeploymentNode()
235
+ events = list(node.run())
236
+
237
+ # THEN the node should have completed successfully
238
+ assert events[-1].name == "greeting"
239
+ assert events[-1].value == "Great!"
240
+
241
+ # AND we should have invoked the Subworkflow Deployment with the expected inputs
242
+ call_kwargs = vellum_client.execute_workflow_stream.call_args.kwargs
243
+ assert call_kwargs["inputs"] == [
244
+ WorkflowRequestNumberInputRequest(name="number", value=42),
245
+ ]
246
+
247
+
137
248
  def test_run_workflow__no_deployment():
138
249
  """Confirm that we raise error when running a subworkflow deployment node with no deployment attribute set"""
139
250