vellum-ai 0.14.11__py3-none-any.whl → 0.14.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/workflows/descriptors/base.py +3 -0
- vellum/workflows/events/workflow.py +23 -0
- vellum/workflows/inputs/base.py +26 -18
- vellum/workflows/inputs/tests/test_inputs.py +1 -1
- vellum/workflows/nodes/bases/base.py +7 -0
- vellum/workflows/nodes/core/inline_subworkflow_node/node.py +7 -0
- vellum/workflows/nodes/core/inline_subworkflow_node/tests/test_node.py +32 -0
- vellum/workflows/nodes/core/map_node/node.py +28 -7
- vellum/workflows/nodes/core/map_node/tests/test_node.py +31 -0
- vellum/workflows/nodes/core/try_node/node.py +7 -0
- vellum/workflows/nodes/core/try_node/tests/test_node.py +32 -0
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +5 -4
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/test_node.py +111 -0
- vellum/workflows/nodes/mocks.py +229 -2
- vellum/workflows/nodes/tests/__init__.py +0 -0
- vellum/workflows/nodes/tests/test_mocks.py +207 -0
- vellum/workflows/outputs/base.py +1 -1
- {vellum_ai-0.14.11.dist-info → vellum_ai-0.14.13.dist-info}/METADATA +2 -2
- {vellum_ai-0.14.11.dist-info → vellum_ai-0.14.13.dist-info}/RECORD +28 -26
- vellum_ee/workflows/display/nodes/base_node_display.py +20 -4
- vellum_ee/workflows/display/nodes/get_node_display_class.py +9 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +24 -1
- vellum_ee/workflows/display/workflows/base_workflow_display.py +2 -2
- vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +20 -0
- {vellum_ai-0.14.11.dist-info → vellum_ai-0.14.13.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.11.dist-info → vellum_ai-0.14.13.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.11.dist-info → vellum_ai-0.14.13.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
|
|
18
18
|
headers: typing.Dict[str, str] = {
|
19
19
|
"X-Fern-Language": "Python",
|
20
20
|
"X-Fern-SDK-Name": "vellum-ai",
|
21
|
-
"X-Fern-SDK-Version": "0.14.
|
21
|
+
"X-Fern-SDK-Version": "0.14.13",
|
22
22
|
}
|
23
23
|
headers["X_API_KEY"] = self.api_key
|
24
24
|
return headers
|
@@ -1,5 +1,6 @@
|
|
1
1
|
from uuid import UUID
|
2
2
|
from typing import TYPE_CHECKING, Any, Dict, Generator, Generic, Iterable, Literal, Optional, Type, Union
|
3
|
+
from typing_extensions import TypeGuard
|
3
4
|
|
4
5
|
from pydantic import field_serializer
|
5
6
|
|
@@ -182,3 +183,25 @@ WorkflowEvent = Union[
|
|
182
183
|
]
|
183
184
|
|
184
185
|
WorkflowEventStream = Generator[WorkflowEvent, None, None]
|
186
|
+
|
187
|
+
WorkflowExecutionEvent = Union[
|
188
|
+
WorkflowExecutionInitiatedEvent,
|
189
|
+
WorkflowExecutionStreamingEvent,
|
190
|
+
WorkflowExecutionRejectedEvent,
|
191
|
+
WorkflowExecutionPausedEvent,
|
192
|
+
WorkflowExecutionResumedEvent,
|
193
|
+
WorkflowExecutionFulfilledEvent,
|
194
|
+
WorkflowExecutionSnapshottedEvent,
|
195
|
+
]
|
196
|
+
|
197
|
+
|
198
|
+
def is_workflow_event(event: WorkflowEvent) -> TypeGuard[WorkflowExecutionEvent]:
|
199
|
+
return (
|
200
|
+
event.name == "workflow.execution.initiated"
|
201
|
+
or event.name == "workflow.execution.fulfilled"
|
202
|
+
or event.name == "workflow.execution.streaming"
|
203
|
+
or event.name == "workflow.execution.snapshotted"
|
204
|
+
or event.name == "workflow.execution.paused"
|
205
|
+
or event.name == "workflow.execution.resumed"
|
206
|
+
or event.name == "workflow.execution.rejected"
|
207
|
+
)
|
vellum/workflows/inputs/base.py
CHANGED
@@ -42,38 +42,46 @@ class BaseInputs(metaclass=_BaseInputsMeta):
|
|
42
42
|
__parent_class__: Type = type(None)
|
43
43
|
|
44
44
|
def __init__(self, **kwargs: Any) -> None:
|
45
|
+
"""
|
46
|
+
Initialize BaseInputs with provided keyword arguments.
|
47
|
+
|
48
|
+
Validation logic:
|
49
|
+
1. Ensures all required fields (non-Optional types) either:
|
50
|
+
- Have a value provided in kwargs, or
|
51
|
+
- Have a default value defined in the class
|
52
|
+
2. Validates that no None values are provided for required fields
|
53
|
+
3. Sets all provided values as attributes on the instance
|
54
|
+
|
55
|
+
Args:
|
56
|
+
**kwargs: Keyword arguments corresponding to the class's type annotations
|
57
|
+
|
58
|
+
Raises:
|
59
|
+
WorkflowInitializationException: If a required field is missing or None
|
60
|
+
"""
|
45
61
|
for name, field_type in self.__class__.__annotations__.items():
|
46
|
-
|
62
|
+
# Get the value (either from kwargs or class default)
|
63
|
+
value = kwargs.get(name)
|
64
|
+
has_default = name in vars(self.__class__)
|
65
|
+
|
66
|
+
if value is None and not has_default:
|
67
|
+
# Check if field_type allows None
|
47
68
|
origin = get_origin(field_type)
|
48
69
|
args = get_args(field_type)
|
49
70
|
if not (origin is Union and type(None) in args):
|
50
71
|
raise WorkflowInitializationException(
|
51
|
-
message="Required input variables should have defined value",
|
72
|
+
message=f"Required input variables {name} should have defined value",
|
52
73
|
code=WorkflowErrorCode.INVALID_INPUTS,
|
53
74
|
)
|
54
75
|
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
self._validate_input(value, field_type)
|
59
|
-
setattr(self, name, value)
|
76
|
+
# If value provided in kwargs, set it on the instance
|
77
|
+
if name in kwargs:
|
78
|
+
setattr(self, name, value)
|
60
79
|
|
61
80
|
def __iter__(self) -> Iterator[Tuple[InputReference, Any]]:
|
62
81
|
for input_descriptor in self.__class__:
|
63
82
|
if hasattr(self, input_descriptor.name):
|
64
83
|
yield (input_descriptor, getattr(self, input_descriptor.name))
|
65
84
|
|
66
|
-
def _validate_input(self, value: Any, field_type: Any) -> None:
|
67
|
-
if value is None:
|
68
|
-
# Check if field_type is Optional
|
69
|
-
origin = get_origin(field_type)
|
70
|
-
args = get_args(field_type)
|
71
|
-
if not (origin is Union and type(None) in args):
|
72
|
-
raise WorkflowInitializationException(
|
73
|
-
message="Required input variables should have defined value",
|
74
|
-
code=WorkflowErrorCode.INVALID_INPUTS,
|
75
|
-
)
|
76
|
-
|
77
85
|
@classmethod
|
78
86
|
def __get_pydantic_core_schema__(
|
79
87
|
cls, source_type: Type[Any], handler: GetCoreSchemaHandler
|
@@ -34,7 +34,7 @@ def test_base_inputs_empty_value():
|
|
34
34
|
|
35
35
|
# THEN it should raise a NodeException with the correct error message and code
|
36
36
|
assert exc_info.value.code == WorkflowErrorCode.INVALID_INPUTS
|
37
|
-
assert "Required input variables should have defined value"
|
37
|
+
assert "Required input variables required_string should have defined value" == str(exc_info.value)
|
38
38
|
|
39
39
|
|
40
40
|
def test_base_inputs_with_default():
|
@@ -19,6 +19,7 @@ from vellum.workflows.ports.port import Port
|
|
19
19
|
from vellum.workflows.references import ExternalInputReference
|
20
20
|
from vellum.workflows.references.execution_count import ExecutionCountReference
|
21
21
|
from vellum.workflows.references.node import NodeReference
|
22
|
+
from vellum.workflows.references.output import OutputReference
|
22
23
|
from vellum.workflows.state.base import BaseState
|
23
24
|
from vellum.workflows.state.context import WorkflowContext
|
24
25
|
from vellum.workflows.types.core import MergeBehavior
|
@@ -118,6 +119,11 @@ class BaseNodeMeta(type):
|
|
118
119
|
node_class.Trigger.node_class = node_class
|
119
120
|
node_class.ExternalInputs.__parent_class__ = node_class
|
120
121
|
node_class.__id__ = uuid4_from_hash(node_class.__qualname__)
|
122
|
+
node_class.__output_ids__ = {
|
123
|
+
ref.name: uuid4_from_hash(f"{node_class.__id__}|{ref.name}")
|
124
|
+
for ref in node_class.Outputs
|
125
|
+
if isinstance(ref, OutputReference)
|
126
|
+
}
|
121
127
|
return node_class
|
122
128
|
|
123
129
|
@property
|
@@ -236,6 +242,7 @@ NodeRunResponse = Union[BaseOutputs, Iterator[BaseOutput]]
|
|
236
242
|
|
237
243
|
class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
|
238
244
|
__id__: UUID = uuid4_from_hash(__qualname__)
|
245
|
+
__output_ids__: Dict[str, UUID] = {}
|
239
246
|
state: StateType
|
240
247
|
_context: WorkflowContext
|
241
248
|
_inputs: MappingProxyType[NodeReference, Any]
|
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Iterator, Option
|
|
3
3
|
from vellum.workflows.constants import undefined
|
4
4
|
from vellum.workflows.context import execution_context, get_parent_context
|
5
5
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
6
|
+
from vellum.workflows.events.workflow import is_workflow_event
|
6
7
|
from vellum.workflows.exceptions import NodeException
|
7
8
|
from vellum.workflows.inputs.base import BaseInputs
|
8
9
|
from vellum.workflows.nodes.bases.base import BaseNode, BaseNodeMeta
|
@@ -86,6 +87,12 @@ class InlineSubworkflowNode(
|
|
86
87
|
|
87
88
|
for event in subworkflow_stream:
|
88
89
|
self._context._emit_subworkflow_event(event)
|
90
|
+
|
91
|
+
if not is_workflow_event(event):
|
92
|
+
continue
|
93
|
+
if event.workflow_definition != self.subworkflow:
|
94
|
+
continue
|
95
|
+
|
89
96
|
if event.name == "workflow.execution.streaming":
|
90
97
|
if event.output.is_fulfilled:
|
91
98
|
fulfilled_output_names.add(event.output.name)
|
@@ -3,6 +3,7 @@ import pytest
|
|
3
3
|
from vellum.workflows.inputs.base import BaseInputs
|
4
4
|
from vellum.workflows.nodes.bases.base import BaseNode
|
5
5
|
from vellum.workflows.nodes.core.inline_subworkflow_node.node import InlineSubworkflowNode
|
6
|
+
from vellum.workflows.nodes.core.try_node.node import TryNode
|
6
7
|
from vellum.workflows.outputs.base import BaseOutput
|
7
8
|
from vellum.workflows.state.base import BaseState
|
8
9
|
from vellum.workflows.workflows.base import BaseWorkflow
|
@@ -55,3 +56,34 @@ def test_inline_subworkflow_node__support_inputs_as_attributes():
|
|
55
56
|
assert events == [
|
56
57
|
BaseOutput(name="out", value="bar"),
|
57
58
|
]
|
59
|
+
|
60
|
+
|
61
|
+
def test_inline_subworkflow_node__nested_try():
|
62
|
+
"""
|
63
|
+
Ensure that the nested try node doesn't affect the subworkflow node's outputs
|
64
|
+
"""
|
65
|
+
|
66
|
+
# GIVEN a nested try node
|
67
|
+
@TryNode.wrap()
|
68
|
+
class InnerNode(BaseNode):
|
69
|
+
class Outputs:
|
70
|
+
foo = "hello"
|
71
|
+
|
72
|
+
# AND a subworkflow
|
73
|
+
class Subworkflow(BaseWorkflow):
|
74
|
+
graph = InnerNode
|
75
|
+
|
76
|
+
class Outputs(BaseWorkflow.Outputs):
|
77
|
+
bar = InnerNode.Outputs.foo
|
78
|
+
|
79
|
+
# AND an outer try node referencing that subworkflow
|
80
|
+
class OuterNode(InlineSubworkflowNode):
|
81
|
+
subworkflow = Subworkflow
|
82
|
+
|
83
|
+
# WHEN we run the try node
|
84
|
+
stream = OuterNode().run()
|
85
|
+
events = list(stream)
|
86
|
+
|
87
|
+
# THEN we only have the outer node's outputs
|
88
|
+
valid_events = [e for e in events if e.name == "bar"]
|
89
|
+
assert len(valid_events) == len(events)
|
@@ -1,4 +1,5 @@
|
|
1
1
|
from collections import defaultdict
|
2
|
+
import logging
|
2
3
|
from queue import Empty, Queue
|
3
4
|
from threading import Thread
|
4
5
|
from typing import (
|
@@ -19,6 +20,7 @@ from typing import (
|
|
19
20
|
from vellum.workflows.context import ExecutionContext, execution_context, get_execution_context
|
20
21
|
from vellum.workflows.descriptors.base import BaseDescriptor
|
21
22
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
23
|
+
from vellum.workflows.events.workflow import is_workflow_event
|
22
24
|
from vellum.workflows.exceptions import NodeException
|
23
25
|
from vellum.workflows.inputs.base import BaseInputs
|
24
26
|
from vellum.workflows.nodes.bases.base_adornment_node import BaseAdornmentNode
|
@@ -33,6 +35,8 @@ from vellum.workflows.workflows.event_filters import all_workflow_event_filter
|
|
33
35
|
if TYPE_CHECKING:
|
34
36
|
from vellum.workflows.events.workflow import WorkflowEvent
|
35
37
|
|
38
|
+
logger = logging.getLogger(__name__)
|
39
|
+
|
36
40
|
MapNodeItemType = TypeVar("MapNodeItemType")
|
37
41
|
|
38
42
|
|
@@ -104,19 +108,36 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
|
|
104
108
|
subworkflow_event = map_node_event[1]
|
105
109
|
self._context._emit_subworkflow_event(subworkflow_event)
|
106
110
|
|
111
|
+
if not is_workflow_event(subworkflow_event):
|
112
|
+
continue
|
113
|
+
|
114
|
+
if subworkflow_event.workflow_definition != self.subworkflow:
|
115
|
+
continue
|
116
|
+
|
107
117
|
if subworkflow_event.name == "workflow.execution.initiated":
|
108
118
|
for output_name in mapped_items.keys():
|
109
119
|
yield BaseOutput(name=output_name, delta=(None, index, "INITIATED"))
|
110
120
|
|
111
121
|
elif subworkflow_event.name == "workflow.execution.fulfilled":
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
122
|
+
for output_reference, output_value in subworkflow_event.outputs:
|
123
|
+
if not isinstance(output_reference, OutputReference):
|
124
|
+
logger.error(
|
125
|
+
"Invalid key to map node's subworkflow event outputs",
|
126
|
+
extra={"output_reference_type": type(output_reference)},
|
127
|
+
)
|
128
|
+
continue
|
129
|
+
|
130
|
+
output_mapped_items = mapped_items[output_reference.name]
|
131
|
+
if index < 0 or index >= len(output_mapped_items):
|
132
|
+
logger.error(
|
133
|
+
"Invalid map node index", extra={"index": index, "output_name": output_reference.name}
|
134
|
+
)
|
135
|
+
continue
|
136
|
+
|
137
|
+
output_mapped_items[index] = output_value
|
117
138
|
yield BaseOutput(
|
118
|
-
name=
|
119
|
-
delta=(
|
139
|
+
name=output_reference.name,
|
140
|
+
delta=(output_value, index, "FULFILLED"),
|
120
141
|
)
|
121
142
|
|
122
143
|
fulfilled_iterations[index] = True
|
@@ -3,8 +3,10 @@ import time
|
|
3
3
|
from vellum.workflows.inputs.base import BaseInputs
|
4
4
|
from vellum.workflows.nodes.bases import BaseNode
|
5
5
|
from vellum.workflows.nodes.core.map_node.node import MapNode
|
6
|
+
from vellum.workflows.nodes.core.try_node.node import TryNode
|
6
7
|
from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
|
7
8
|
from vellum.workflows.state.base import BaseState, StateMeta
|
9
|
+
from vellum.workflows.workflows.base import BaseWorkflow
|
8
10
|
|
9
11
|
|
10
12
|
def test_map_node__use_parent_inputs_and_state():
|
@@ -85,3 +87,32 @@ def test_map_node__empty_list():
|
|
85
87
|
# THEN the node should return an empty output
|
86
88
|
fulfilled_output = outputs[-1]
|
87
89
|
assert fulfilled_output == BaseOutput(name="value", value=[])
|
90
|
+
|
91
|
+
|
92
|
+
def test_map_node__inner_try():
|
93
|
+
# GIVEN a try wrapped node
|
94
|
+
@TryNode.wrap()
|
95
|
+
class InnerNode(BaseNode):
|
96
|
+
class Outputs(BaseNode.Outputs):
|
97
|
+
foo: str
|
98
|
+
|
99
|
+
# AND a workflow using that node
|
100
|
+
class SimpleMapNodeWorkflow(BaseWorkflow[MapNode.SubworkflowInputs, BaseState]):
|
101
|
+
graph = InnerNode
|
102
|
+
|
103
|
+
class Outputs(BaseWorkflow.Outputs):
|
104
|
+
final_output = InnerNode.Outputs.foo
|
105
|
+
|
106
|
+
# AND a map node referencing that workflow
|
107
|
+
class SimpleMapNode(MapNode):
|
108
|
+
items = ["hello", "world"]
|
109
|
+
subworkflow = SimpleMapNodeWorkflow
|
110
|
+
max_concurrency = 4
|
111
|
+
|
112
|
+
# WHEN we run the workflow
|
113
|
+
stream = SimpleMapNode().run()
|
114
|
+
outputs = list(stream)
|
115
|
+
|
116
|
+
# THEN the workflow should succeed
|
117
|
+
assert outputs[-1].name == "final_output"
|
118
|
+
assert len(outputs[-1].value) == 2
|
@@ -2,6 +2,7 @@ from typing import Callable, Generic, Iterator, Optional, Set, Type
|
|
2
2
|
|
3
3
|
from vellum.workflows.context import execution_context, get_parent_context
|
4
4
|
from vellum.workflows.errors.types import WorkflowError, WorkflowErrorCode
|
5
|
+
from vellum.workflows.events.workflow import is_workflow_event
|
5
6
|
from vellum.workflows.exceptions import NodeException
|
6
7
|
from vellum.workflows.nodes.bases import BaseNode
|
7
8
|
from vellum.workflows.nodes.bases.base_adornment_node import BaseAdornmentNode
|
@@ -47,6 +48,12 @@ class TryNode(BaseAdornmentNode[StateType], Generic[StateType]):
|
|
47
48
|
if exception:
|
48
49
|
continue
|
49
50
|
|
51
|
+
if not is_workflow_event(event):
|
52
|
+
continue
|
53
|
+
|
54
|
+
if event.workflow_definition != self.subworkflow:
|
55
|
+
continue
|
56
|
+
|
50
57
|
if event.name == "workflow.execution.streaming":
|
51
58
|
if event.output.is_fulfilled:
|
52
59
|
fulfilled_output_names.add(event.output.name)
|
@@ -10,6 +10,7 @@ from vellum.workflows.outputs import BaseOutputs
|
|
10
10
|
from vellum.workflows.outputs.base import BaseOutput
|
11
11
|
from vellum.workflows.state.base import BaseState, StateMeta
|
12
12
|
from vellum.workflows.state.context import WorkflowContext
|
13
|
+
from vellum.workflows.workflows.base import BaseWorkflow
|
13
14
|
|
14
15
|
|
15
16
|
def test_try_node__on_error_code__successfully_caught():
|
@@ -126,3 +127,34 @@ def test_try_node__resolved_inputs():
|
|
126
127
|
foo = State.counter
|
127
128
|
|
128
129
|
assert MyNode.foo.types == (float,)
|
130
|
+
|
131
|
+
|
132
|
+
def test_try_node__nested_try():
|
133
|
+
"""
|
134
|
+
Ensure that the nested try node doesn't affect the outer try node's outputs
|
135
|
+
"""
|
136
|
+
|
137
|
+
# GIVEN a nested try node
|
138
|
+
@TryNode.wrap()
|
139
|
+
class InnerNode(BaseNode):
|
140
|
+
class Outputs:
|
141
|
+
foo = "hello"
|
142
|
+
|
143
|
+
# AND a subworkflow
|
144
|
+
class Subworkflow(BaseWorkflow):
|
145
|
+
graph = InnerNode
|
146
|
+
|
147
|
+
class Outputs(BaseWorkflow.Outputs):
|
148
|
+
bar = InnerNode.Outputs.foo
|
149
|
+
|
150
|
+
# AND an outer try node referencing that subworkflow
|
151
|
+
class OuterNode(TryNode):
|
152
|
+
subworkflow = Subworkflow
|
153
|
+
|
154
|
+
# WHEN we run the try node
|
155
|
+
stream = OuterNode().run()
|
156
|
+
events = list(stream)
|
157
|
+
|
158
|
+
# THEN we only have the outer node's outputs
|
159
|
+
valid_events = [e for e in events if e.name == "bar"]
|
160
|
+
assert len(valid_events) == len(events)
|
@@ -70,8 +70,10 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
|
|
70
70
|
value=input_value,
|
71
71
|
)
|
72
72
|
)
|
73
|
-
elif
|
74
|
-
isinstance(
|
73
|
+
elif (
|
74
|
+
isinstance(input_value, list)
|
75
|
+
and len(input_value) > 0
|
76
|
+
and all(isinstance(message, (ChatMessage, ChatMessageRequest)) for message in input_value)
|
75
77
|
):
|
76
78
|
chat_history = [
|
77
79
|
(
|
@@ -95,7 +97,7 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
|
|
95
97
|
value=cast(Dict[str, Any], input_value),
|
96
98
|
)
|
97
99
|
)
|
98
|
-
elif isinstance(input_value, float):
|
100
|
+
elif isinstance(input_value, (int, float)):
|
99
101
|
compiled_inputs.append(
|
100
102
|
WorkflowRequestNumberInputRequest(
|
101
103
|
name=input_name,
|
@@ -110,7 +112,6 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
|
|
110
112
|
message=f"Failed to serialize input '{input_name}' of type '{input_value.__class__}': {e}",
|
111
113
|
code=WorkflowErrorCode.INVALID_INPUTS,
|
112
114
|
)
|
113
|
-
|
114
115
|
compiled_inputs.append(
|
115
116
|
WorkflowRequestJsonInputRequest(
|
116
117
|
name=input_name,
|
@@ -10,6 +10,7 @@ from vellum.client.types.workflow_execution_workflow_result_event import Workflo
|
|
10
10
|
from vellum.client.types.workflow_output_string import WorkflowOutputString
|
11
11
|
from vellum.client.types.workflow_request_chat_history_input_request import WorkflowRequestChatHistoryInputRequest
|
12
12
|
from vellum.client.types.workflow_request_json_input_request import WorkflowRequestJsonInputRequest
|
13
|
+
from vellum.client.types.workflow_request_number_input_request import WorkflowRequestNumberInputRequest
|
13
14
|
from vellum.client.types.workflow_result_event import WorkflowResultEvent
|
14
15
|
from vellum.client.types.workflow_stream_event import WorkflowStreamEvent
|
15
16
|
from vellum.workflows.errors import WorkflowErrorCode
|
@@ -134,6 +135,116 @@ def test_run_workflow__any_array(vellum_client):
|
|
134
135
|
]
|
135
136
|
|
136
137
|
|
138
|
+
def test_run_workflow__empty_array(vellum_client):
|
139
|
+
# GIVEN a Subworkflow Deployment Node
|
140
|
+
class ExampleSubworkflowDeploymentNode(SubworkflowDeploymentNode):
|
141
|
+
deployment = "example_subworkflow_deployment"
|
142
|
+
subworkflow_inputs = {
|
143
|
+
"fruits": [],
|
144
|
+
}
|
145
|
+
|
146
|
+
# AND we know what the Subworkflow Deployment will respond with
|
147
|
+
def generate_subworkflow_events(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
|
148
|
+
execution_id = str(uuid4())
|
149
|
+
expected_events: List[WorkflowStreamEvent] = [
|
150
|
+
WorkflowExecutionWorkflowResultEvent(
|
151
|
+
execution_id=execution_id,
|
152
|
+
data=WorkflowResultEvent(
|
153
|
+
id=str(uuid4()),
|
154
|
+
state="INITIATED",
|
155
|
+
ts=datetime.now(),
|
156
|
+
),
|
157
|
+
),
|
158
|
+
WorkflowExecutionWorkflowResultEvent(
|
159
|
+
execution_id=execution_id,
|
160
|
+
data=WorkflowResultEvent(
|
161
|
+
id=str(uuid4()),
|
162
|
+
state="FULFILLED",
|
163
|
+
ts=datetime.now(),
|
164
|
+
outputs=[
|
165
|
+
WorkflowOutputString(
|
166
|
+
id=str(uuid4()),
|
167
|
+
name="greeting",
|
168
|
+
value="Great!",
|
169
|
+
)
|
170
|
+
],
|
171
|
+
),
|
172
|
+
),
|
173
|
+
]
|
174
|
+
yield from expected_events
|
175
|
+
|
176
|
+
vellum_client.execute_workflow_stream.side_effect = generate_subworkflow_events
|
177
|
+
|
178
|
+
# WHEN we run the node
|
179
|
+
node = ExampleSubworkflowDeploymentNode()
|
180
|
+
events = list(node.run())
|
181
|
+
|
182
|
+
# THEN the node should have completed successfully
|
183
|
+
assert events[-1].name == "greeting"
|
184
|
+
assert events[-1].value == "Great!"
|
185
|
+
|
186
|
+
# AND we should have invoked the Subworkflow Deployment with the expected inputs
|
187
|
+
call_kwargs = vellum_client.execute_workflow_stream.call_args.kwargs
|
188
|
+
assert call_kwargs["inputs"] == [
|
189
|
+
WorkflowRequestJsonInputRequest(name="fruits", value=[]),
|
190
|
+
]
|
191
|
+
|
192
|
+
|
193
|
+
def test_run_workflow__int_input(vellum_client):
|
194
|
+
# GIVEN a Subworkflow Deployment Node
|
195
|
+
class ExampleSubworkflowDeploymentNode(SubworkflowDeploymentNode):
|
196
|
+
deployment = "example_subworkflow_deployment"
|
197
|
+
subworkflow_inputs = {
|
198
|
+
"number": 42,
|
199
|
+
}
|
200
|
+
|
201
|
+
# AND we know what the Subworkflow Deployment will respond with
|
202
|
+
def generate_subworkflow_events(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
|
203
|
+
execution_id = str(uuid4())
|
204
|
+
expected_events: List[WorkflowStreamEvent] = [
|
205
|
+
WorkflowExecutionWorkflowResultEvent(
|
206
|
+
execution_id=execution_id,
|
207
|
+
data=WorkflowResultEvent(
|
208
|
+
id=str(uuid4()),
|
209
|
+
state="INITIATED",
|
210
|
+
ts=datetime.now(),
|
211
|
+
),
|
212
|
+
),
|
213
|
+
WorkflowExecutionWorkflowResultEvent(
|
214
|
+
execution_id=execution_id,
|
215
|
+
data=WorkflowResultEvent(
|
216
|
+
id=str(uuid4()),
|
217
|
+
state="FULFILLED",
|
218
|
+
ts=datetime.now(),
|
219
|
+
outputs=[
|
220
|
+
WorkflowOutputString(
|
221
|
+
id=str(uuid4()),
|
222
|
+
name="greeting",
|
223
|
+
value="Great!",
|
224
|
+
)
|
225
|
+
],
|
226
|
+
),
|
227
|
+
),
|
228
|
+
]
|
229
|
+
yield from expected_events
|
230
|
+
|
231
|
+
vellum_client.execute_workflow_stream.side_effect = generate_subworkflow_events
|
232
|
+
|
233
|
+
# WHEN we run the node
|
234
|
+
node = ExampleSubworkflowDeploymentNode()
|
235
|
+
events = list(node.run())
|
236
|
+
|
237
|
+
# THEN the node should have completed successfully
|
238
|
+
assert events[-1].name == "greeting"
|
239
|
+
assert events[-1].value == "Great!"
|
240
|
+
|
241
|
+
# AND we should have invoked the Subworkflow Deployment with the expected inputs
|
242
|
+
call_kwargs = vellum_client.execute_workflow_stream.call_args.kwargs
|
243
|
+
assert call_kwargs["inputs"] == [
|
244
|
+
WorkflowRequestNumberInputRequest(name="number", value=42),
|
245
|
+
]
|
246
|
+
|
247
|
+
|
137
248
|
def test_run_workflow__no_deployment():
|
138
249
|
"""Confirm that we raise error when running a subworkflow deployment node with no deployment attribute set"""
|
139
250
|
|