vellum-ai 0.14.12__py3-none-any.whl → 0.14.14__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/plugins/utils.py +11 -3
- vellum/workflows/descriptors/base.py +3 -0
- vellum/workflows/events/node.py +5 -0
- vellum/workflows/events/tests/test_event.py +36 -0
- vellum/workflows/events/workflow.py +23 -0
- vellum/workflows/inputs/base.py +26 -18
- vellum/workflows/inputs/tests/test_inputs.py +1 -1
- vellum/workflows/nodes/bases/base.py +7 -0
- vellum/workflows/nodes/core/inline_subworkflow_node/node.py +7 -0
- vellum/workflows/nodes/core/inline_subworkflow_node/tests/test_node.py +32 -0
- vellum/workflows/nodes/core/map_node/node.py +28 -7
- vellum/workflows/nodes/core/map_node/tests/test_node.py +31 -0
- vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +16 -0
- vellum/workflows/nodes/core/try_node/node.py +7 -0
- vellum/workflows/nodes/core/try_node/tests/test_node.py +32 -0
- vellum/workflows/nodes/mocks.py +229 -2
- vellum/workflows/nodes/tests/__init__.py +0 -0
- vellum/workflows/nodes/tests/test_mocks.py +207 -0
- vellum/workflows/nodes/tests/test_utils.py +133 -0
- vellum/workflows/nodes/utils.py +17 -3
- vellum/workflows/outputs/base.py +1 -1
- vellum/workflows/runner/runner.py +2 -1
- {vellum_ai-0.14.12.dist-info → vellum_ai-0.14.14.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.12.dist-info → vellum_ai-0.14.14.dist-info}/RECORD +33 -30
- vellum_ee/workflows/display/nodes/base_node_display.py +20 -4
- vellum_ee/workflows/display/nodes/get_node_display_class.py +9 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +24 -1
- vellum_ee/workflows/display/workflows/base_workflow_display.py +2 -2
- vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +20 -0
- {vellum_ai-0.14.12.dist-info → vellum_ai-0.14.14.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.12.dist-info → vellum_ai-0.14.14.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.12.dist-info → vellum_ai-0.14.14.dist-info}/entry_points.txt +0 -0
vellum/workflows/nodes/mocks.py
CHANGED
@@ -1,10 +1,102 @@
-from
+from functools import reduce
+from uuid import UUID
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Sequence, Type, Union
 
-from pydantic import ConfigDict
+from pydantic import ConfigDict, ValidationError
 
 from vellum.client.core.pydantic_utilities import UniversalBaseModel
+from vellum.client.types.array_vellum_value import ArrayVellumValue
+from vellum.client.types.vellum_value import VellumValue
 from vellum.workflows.descriptors.base import BaseDescriptor
+from vellum.workflows.errors.types import WorkflowErrorCode
+from vellum.workflows.exceptions import WorkflowInitializationException
 from vellum.workflows.outputs.base import BaseOutputs
+from vellum.workflows.references.constant import ConstantValueReference
+
+if TYPE_CHECKING:
+    from vellum.workflows import BaseWorkflow
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class _RawLogicalCondition(UniversalBaseModel):
+    type: Literal["LOGICAL_CONDITION"] = "LOGICAL_CONDITION"
+    lhs_variable_id: UUID
+    operator: Literal["==", ">", ">=", "<", "<=", "!="]
+    rhs_variable_id: UUID
+
+
+class _RawLogicalConditionGroup(UniversalBaseModel):
+    type: Literal["LOGICAL_CONDITION_GROUP"] = "LOGICAL_CONDITION_GROUP"
+    conditions: List["_RawLogicalExpression"]
+    combinator: Literal["AND", "OR"]
+    negated: bool
+
+
+_RawLogicalExpression = Union[_RawLogicalCondition, _RawLogicalConditionGroup]
+
+
+class _RawLogicalExpressionVariable(UniversalBaseModel):
+    id: UUID
+
+
+class _RawMockWorkflowNodeExecutionConstantValuePointer(_RawLogicalExpressionVariable):
+    type: Literal["CONSTANT_VALUE"] = "CONSTANT_VALUE"
+    variable_value: VellumValue
+
+
+class _RawMockWorkflowNodeExecutionNodeExecutionCounterPointer(_RawLogicalExpressionVariable):
+    type: Literal["EXECUTION_COUNTER"] = "EXECUTION_COUNTER"
+    node_id: UUID
+
+
+class _RawMockWorkflowNodeExecutionInputVariablePointer(_RawLogicalExpressionVariable):
+    type: Literal["INPUT_VARIABLE"] = "INPUT_VARIABLE"
+    input_variable_id: UUID
+
+
+class _RawMockWorkflowNodeExecutionNodeOutputPointer(_RawLogicalExpressionVariable):
+    type: Literal["NODE_OUTPUT"] = "NODE_OUTPUT"
+    node_id: UUID
+    input_id: UUID
+
+
+class _RawMockWorkflowNodeExecutionNodeInputPointer(_RawLogicalExpressionVariable):
+    type: Literal["NODE_INPUT"] = "NODE_INPUT"
+    node_id: UUID
+    input_id: UUID
+
+
+_RawMockWorkflowNodeExecutionValuePointer = Union[
+    _RawMockWorkflowNodeExecutionConstantValuePointer,
+    _RawMockWorkflowNodeExecutionNodeExecutionCounterPointer,
+    _RawMockWorkflowNodeExecutionInputVariablePointer,
+    _RawMockWorkflowNodeExecutionNodeOutputPointer,
+    _RawMockWorkflowNodeExecutionNodeInputPointer,
+]
+
+
+class _RawMockWorkflowNodeWhenCondition(UniversalBaseModel):
+    expression: _RawLogicalExpression
+    variables: List[_RawMockWorkflowNodeExecutionValuePointer]
+
+
+class _RawMockWorkflowNodeThenOutput(UniversalBaseModel):
+    output_id: UUID
+    value: _RawMockWorkflowNodeExecutionValuePointer
+
+
+class _RawMockWorkflowNodeExecution(UniversalBaseModel):
+    when_condition: _RawMockWorkflowNodeWhenCondition
+    then_outputs: List[_RawMockWorkflowNodeThenOutput]
+
+
+class _RawMockWorkflowNodeConfig(UniversalBaseModel):
+    type: Literal["WORKFLOW_NODE_OUTPUT"] = "WORKFLOW_NODE_OUTPUT"
+    node_id: UUID
+    mock_executions: List[_RawMockWorkflowNodeExecution]
 
 
 class MockNodeExecution(UniversalBaseModel):
@@ -13,5 +105,140 @@ class MockNodeExecution(UniversalBaseModel):
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
+    @staticmethod
+    def validate_all(
+        raw_mock_workflow_node_configs: Optional[List[Any]],
+        workflow: Type["BaseWorkflow"],
+    ) -> Optional[List["MockNodeExecution"]]:
+        if not raw_mock_workflow_node_configs:
+            return None
+
+        ArrayVellumValue.model_rebuild()
+        try:
+            mock_workflow_node_configs = [
+                _RawMockWorkflowNodeConfig.model_validate(raw_mock_workflow_node_config)
+                for raw_mock_workflow_node_config in raw_mock_workflow_node_configs
+            ]
+        except ValidationError as e:
+            raise WorkflowInitializationException(
+                message="Failed to validate mock node executions",
+                code=WorkflowErrorCode.INVALID_INPUTS,
+            ) from e
+
+        nodes = {node.__id__: node for node in workflow.get_nodes()}
+        node_output_name_by_id = {
+            node.__output_ids__[output.name]: output.name for node in workflow.get_nodes() for output in node.Outputs
+        }
+
+        # We need to support the old way that the Vellum App's WorkflowRunner used to define Node Mocks in order to
+        # avoid needing to update the mock resolution strategy that it and the frontend uses. The path towards
+        # cleaning this up will go as follows:
+        # 1. Release Mock support in SDK-Enabled Workflows
+        # 2. Deprecate Mock support in non-SDK enabled Workflows, encouraging users to migrate to SDK-enabled Workflows
+        # 3. Remove the old mock resolution strategy
+        # 4. Update this SDK to handle the new mock resolution strategy with WorkflowValueDescriptors
+        # 5. Cutover the Vellum App to the new mock resolution strategy
+        # 6. Remove the old mock resolution strategy from this SDK
+        def _translate_raw_logical_expression(
+            raw_logical_expression: _RawLogicalExpression,
+            raw_variables: List[_RawMockWorkflowNodeExecutionValuePointer],
+        ) -> BaseDescriptor:
+            if raw_logical_expression.type == "LOGICAL_CONDITION":
+                return _translate_raw_logical_condition(raw_logical_expression, raw_variables)
+            else:
+                return _translate_raw_logical_condition_group(raw_logical_expression, raw_variables)
+
+        def _translate_raw_logical_condition_group(
+            raw_logical_condition_group: _RawLogicalConditionGroup,
+            raw_variables: List[_RawMockWorkflowNodeExecutionValuePointer],
+        ) -> BaseDescriptor:
+            if not raw_logical_condition_group.conditions:
+                return ConstantValueReference(True)
+
+            conditions = [
+                _translate_raw_logical_expression(condition, raw_variables)
+                for condition in raw_logical_condition_group.conditions
+            ]
+            return reduce(
+                lambda acc, condition: (
+                    acc and condition if raw_logical_condition_group.combinator == "AND" else acc or condition
+                ),
+                conditions,
+            )
+
+        def _translate_raw_logical_condition(
+            raw_logical_condition: _RawLogicalCondition,
+            raw_variables: List[_RawMockWorkflowNodeExecutionValuePointer],
+        ) -> BaseDescriptor:
+            variable_by_id = {v.id: v for v in raw_variables}
+            lhs = _translate_raw_logical_expression_variable(variable_by_id[raw_logical_condition.lhs_variable_id])
+            rhs = _translate_raw_logical_expression_variable(variable_by_id[raw_logical_condition.rhs_variable_id])
+            if raw_logical_condition.operator == ">":
+                return lhs.greater_than(rhs)
+            elif raw_logical_condition.operator == ">=":
+                return lhs.greater_than_or_equal_to(rhs)
+            elif raw_logical_condition.operator == "<":
+                return lhs.less_than(rhs)
+            elif raw_logical_condition.operator == "<=":
+                return lhs.less_than_or_equal_to(rhs)
+            elif raw_logical_condition.operator == "==":
+                return lhs.equals(rhs)
+            elif raw_logical_condition.operator == "!=":
+                return lhs.does_not_equal(rhs)
+            else:
+                raise WorkflowInitializationException(f"Unsupported logical operator: {raw_logical_condition.operator}")
+
+        def _translate_raw_logical_expression_variable(
+            raw_variable: _RawMockWorkflowNodeExecutionValuePointer,
+        ) -> BaseDescriptor:
+            if raw_variable.type == "CONSTANT_VALUE":
+                return ConstantValueReference(raw_variable.variable_value.value)
+            elif raw_variable.type == "EXECUTION_COUNTER":
+                node = nodes[raw_variable.node_id]
+                return node.Execution.count
+            else:
+                raise WorkflowInitializationException(f"Unsupported logical expression type: {raw_variable.type}")
+
+        mock_node_executions = []
+        for mock_workflow_node_config in mock_workflow_node_configs:
+            for mock_execution in mock_workflow_node_config.mock_executions:
+                try:
+                    when_condition = _translate_raw_logical_expression(
+                        mock_execution.when_condition.expression,
+                        mock_execution.when_condition.variables,
+                    )
+
+                    then_outputs = nodes[mock_workflow_node_config.node_id].Outputs()
+                    for then_output in mock_execution.then_outputs:
+                        node_output_name = node_output_name_by_id.get(then_output.output_id)
+                        if node_output_name is None:
+                            raise WorkflowInitializationException(
+                                f"Output {then_output.output_id} not found in node {mock_workflow_node_config.node_id}"
+                            )
+
+                        resolved_output_reference = _translate_raw_logical_expression_variable(then_output.value)
+                        if isinstance(resolved_output_reference, ConstantValueReference):
+                            setattr(
+                                then_outputs,
+                                node_output_name,
+                                resolved_output_reference._value,
+                            )
+                        else:
+                            raise WorkflowInitializationException(
+                                f"Unsupported resolved output reference type: {type(resolved_output_reference)}"
+                            )
+
+                    mock_node_executions.append(
+                        MockNodeExecution(
+                            when_condition=when_condition,
+                            then_outputs=then_outputs,
+                        )
+                    )
+                except Exception as e:
+                    logger.exception("Failed to validate mock node execution", exc_info=e)
+                    continue
+
+        return mock_node_executions
+
 
 MockNodeExecutionArg = Sequence[Union[BaseOutputs, MockNodeExecution]]
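The group translation above folds its child conditions pairwise with functools.reduce, keyed on the group's combinator, and treats an empty group as vacuously true via ConstantValueReference(True). Below is a standalone sketch of that folding pattern on plain booleans; it is illustrative only (the real helper folds BaseDescriptor expressions, and fold_conditions is a made-up name for the example):

from functools import reduce


def fold_conditions(conditions, combinator):
    # Same shape as _translate_raw_logical_condition_group: an empty group is
    # vacuously true; otherwise the children are folded pairwise with AND/OR.
    if not conditions:
        return True
    return reduce(
        lambda acc, condition: (acc and condition) if combinator == "AND" else (acc or condition),
        conditions,
    )


assert fold_conditions([], "AND") is True
assert fold_conditions([True, True, False], "AND") is False
assert fold_conditions([False, True], "OR") is True
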
vellum/workflows/nodes/tests/__init__.py
ADDED
File without changes

vellum/workflows/nodes/tests/test_mocks.py
ADDED
@@ -0,0 +1,207 @@
+import uuid
+
+from vellum.client.types.string_vellum_value import StringVellumValue
+from vellum.workflows import BaseWorkflow
+from vellum.workflows.nodes import InlinePromptNode
+from vellum.workflows.nodes.bases.base import BaseNode
+from vellum.workflows.nodes.mocks import MockNodeExecution
+from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
+from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay
+
+
+def test_mocks__parse_from_app():
+    # GIVEN a PromptNode
+    class PromptNode(InlinePromptNode):
+        pass
+
+    # AND a workflow class with that PromptNode
+    class MyWorkflow(BaseWorkflow):
+        graph = PromptNode
+
+    # AND a mock workflow node execution from the app
+    raw_mock_workflow_node_execution = [
+        {
+            "type": "WORKFLOW_NODE_OUTPUT",
+            "node_id": str(PromptNode.__id__),
+            "mock_executions": [
+                {
+                    "when_condition": {
+                        "expression": {
+                            "type": "LOGICAL_CONDITION_GROUP",
+                            "combinator": "AND",
+                            "negated": False,
+                            "conditions": [
+                                {
+                                    "type": "LOGICAL_CONDITION",
+                                    "lhs_variable_id": "e60902d5-6892-4916-80c1-f0130af52322",
+                                    "operator": ">=",
+                                    "rhs_variable_id": "5c1bbb24-c288-49cb-a9b7-0c6f38a86037",
+                                }
+                            ],
+                        },
+                        "variables": [
+                            {
+                                "type": "EXECUTION_COUNTER",
+                                "node_id": str(PromptNode.__id__),
+                                "id": "e60902d5-6892-4916-80c1-f0130af52322",
+                            },
+                            {
+                                "type": "CONSTANT_VALUE",
+                                "variable_value": {"type": "NUMBER", "value": 0},
+                                "id": "5c1bbb24-c288-49cb-a9b7-0c6f38a86037",
+                            },
+                        ],
+                    },
+                    "then_outputs": [
+                        {
+                            "output_id": "9e6dc5d3-8ea0-4346-8a2a-7cce5495755b",
+                            "value": {
+                                "id": "27006b2a-fa81-430c-a0b2-c66a9351fc68",
+                                "type": "CONSTANT_VALUE",
+                                "variable_value": {"type": "STRING", "value": "Hello"},
+                            },
+                        },
+                        {
+                            "output_id": "60305ffd-60b0-42aa-b54e-4fdae0f8c28a",
+                            "value": {
+                                "id": "4559c778-6e27-4cfe-a460-734ba62a5082",
+                                "type": "CONSTANT_VALUE",
+                                "variable_value": {"type": "ARRAY", "value": [{"type": "STRING", "value": "Hello"}]},
+                            },
+                        },
+                    ],
+                }
+            ],
+        }
+    ]
+
+    # WHEN we parse the mock workflow node execution
+    node_output_mocks = MockNodeExecution.validate_all(
+        raw_mock_workflow_node_execution,
+        MyWorkflow,
+    )
+
+    # THEN we get a list of MockNodeExecution objects
+    assert node_output_mocks
+    assert len(node_output_mocks) == 1
+    assert node_output_mocks[0] == MockNodeExecution(
+        when_condition=PromptNode.Execution.count.greater_than_or_equal_to(0.0),
+        then_outputs=PromptNode.Outputs(
+            text="Hello",
+            results=[
+                StringVellumValue(value="Hello"),
+            ],
+        ),
+    )
+
+
+def test_mocks__parse_none_still_runs():
+    # GIVEN a Base Node
+    class StartNode(BaseNode):
+        class Outputs(BaseNode.Outputs):
+            foo: str
+
+    # AND a workflow class with that Node
+    class MyWorkflow(BaseWorkflow):
+        graph = StartNode
+
+        class Outputs(BaseWorkflow.Outputs):
+            final_value = StartNode.Outputs.foo
+
+    # AND we parsed `None` on `MockNodeExecution`
+    node_output_mocks = MockNodeExecution.validate_all(
+        None,
+        MyWorkflow,
+    )
+
+    # WHEN we run the workflow
+    workflow = MyWorkflow()
+    final_event = workflow.run(node_output_mocks=node_output_mocks)
+
+    # THEN it was successful
+    assert final_event.name == "workflow.execution.fulfilled"
+
+
+def test_mocks__use_id_from_display():
+    # GIVEN a Base Node
+    class StartNode(BaseNode):
+        class Outputs(BaseNode.Outputs):
+            foo: str
+
+    # AND a workflow class with that Node
+    class MyWorkflow(BaseWorkflow):
+        graph = StartNode
+
+        class Outputs(BaseWorkflow.Outputs):
+            final_value = StartNode.Outputs.foo
+
+    # AND a display class on that Base Node
+    node_output_id = uuid.uuid4()
+
+    class StartNodeDisplay(BaseNodeDisplay[StartNode]):
+        output_display = {StartNode.Outputs.foo: NodeOutputDisplay(id=node_output_id, name="foo")}
+
+    # AND a mock workflow node execution from the app
+    raw_mock_workflow_node_execution = [
+        {
+            "type": "WORKFLOW_NODE_OUTPUT",
+            "node_id": str(StartNode.__id__),
+            "mock_executions": [
+                {
+                    "when_condition": {
+                        "expression": {
+                            "type": "LOGICAL_CONDITION_GROUP",
+                            "combinator": "AND",
+                            "negated": False,
+                            "conditions": [
+                                {
+                                    "type": "LOGICAL_CONDITION",
+                                    "lhs_variable_id": "e60902d5-6892-4916-80c1-f0130af52322",
+                                    "operator": ">=",
+                                    "rhs_variable_id": "5c1bbb24-c288-49cb-a9b7-0c6f38a86037",
+                                }
+                            ],
+                        },
+                        "variables": [
+                            {
+                                "type": "EXECUTION_COUNTER",
+                                "node_id": str(StartNode.__id__),
+                                "id": "e60902d5-6892-4916-80c1-f0130af52322",
+                            },
+                            {
+                                "type": "CONSTANT_VALUE",
+                                "variable_value": {"type": "NUMBER", "value": 0},
+                                "id": "5c1bbb24-c288-49cb-a9b7-0c6f38a86037",
+                            },
+                        ],
+                    },
+                    "then_outputs": [
+                        {
+                            "output_id": str(node_output_id),
+                            "value": {
+                                "id": "27006b2a-fa81-430c-a0b2-c66a9351fc68",
+                                "type": "CONSTANT_VALUE",
+                                "variable_value": {"type": "STRING", "value": "Hello"},
+                            },
+                        },
+                    ],
+                }
+            ],
+        }
+    ]
+
+    # WHEN we parsed the raw data on `MockNodeExecution`
+    node_output_mocks = MockNodeExecution.validate_all(
+        raw_mock_workflow_node_execution,
+        MyWorkflow,
+    )
+
+    # THEN we get the expected list of MockNodeExecution objects
+    assert node_output_mocks
+    assert len(node_output_mocks) == 1
+    assert node_output_mocks[0] == MockNodeExecution(
+        when_condition=StartNode.Execution.count.greater_than_or_equal_to(0.0),
+        then_outputs=StartNode.Outputs(
+            foo="Hello",
+        ),
+    )
vellum/workflows/nodes/tests/test_utils.py
ADDED
@@ -0,0 +1,133 @@
+import pytest
+from typing import List, Union
+
+from pydantic import BaseModel
+
+from vellum.workflows.errors.types import WorkflowErrorCode
+from vellum.workflows.exceptions import NodeException
+from vellum.workflows.nodes.utils import parse_type_from_str
+from vellum.workflows.types.core import Json
+
+
+class Person(BaseModel):
+    name: str
+    age: int
+
+
+class FunctionCall(BaseModel):
+    name: str
+    args: List[int]
+
+
+@pytest.mark.parametrize(
+    "input_str, output_type, expected_result",
+    [
+        ("hello", str, "hello"),
+        ("3.14", float, 3.14),
+        ("42", int, 42),
+        ("True", bool, True),
+        ("", List[str], []),  # Empty string should return an empty list
+        ("[1, 2, 3]", List[int], [1, 2, 3]),
+        ('["a", "b", "c"]', List[str], ["a", "b", "c"]),
+        ('{"name": "Alice", "age": 30}', Json, {"name": "Alice", "age": 30}),
+        (
+            '{"type": "FUNCTION_CALL", "value": {"name": "test", "args": [1, 2]}}',
+            Json,
+            {"name": "test", "args": [1, 2]},
+        ),
+        ("42", Union[int, str], 42),
+        ("hello", Union[int, str], "hello"),
+    ],
+    ids=[
+        "str",
+        "float",
+        "int",
+        "bool",
+        "empty_list",
+        "list_of_int",
+        "list_of_str",
+        "simple_json",
+        "function_call_json",
+        "union_int",
+        "union_str",
+    ],
+)
+def test_parse_type_from_str_basic_cases(input_str, output_type, expected_result):
+    result = parse_type_from_str(input_str, output_type)
+    assert result == expected_result
+
+
+def test_parse_type_from_str_pydantic_models():
+    person_json = '{"name": "Alice", "age": 30}'
+    person = parse_type_from_str(person_json, Person)
+    assert isinstance(person, Person)
+    assert person.name == "Alice"
+    assert person.age == 30
+
+    function_json = '{"name": "test", "args": [1, 2]}'
+    function = parse_type_from_str(function_json, FunctionCall)
+    assert isinstance(function, FunctionCall)
+    assert function.name == "test"
+    assert function.args == [1, 2]
+
+    function_call_json = '{"value": {"name": "test", "args": [1, 2]}}'
+    function = parse_type_from_str(function_call_json, FunctionCall)
+    assert isinstance(function, FunctionCall)
+    assert function.name == "test"
+    assert function.args == [1, 2]
+
+
+def test_parse_type_from_str_list_of_models():
+    person_list_json = '[{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]'
+    persons = parse_type_from_str(person_list_json, List[Person])
+    assert len(persons) == 2
+    assert all(isinstance(p, Person) for p in persons)
+    assert persons[0].name == "Alice"
+    assert persons[0].age == 30
+    assert persons[1].name == "Bob"
+    assert persons[1].age == 25
+
+
+@pytest.mark.parametrize(
+    "input_str, output_type, expected_exception, expected_code, expected_message",
+    [
+        (
+            "{invalid json}",
+            List[str],
+            NodeException,
+            WorkflowErrorCode.INVALID_OUTPUTS,
+            "Invalid JSON Array format for result_as_str",
+        ),
+        (
+            "{invalid json}",
+            Person,
+            NodeException,
+            WorkflowErrorCode.INVALID_OUTPUTS,
+            "Invalid JSON format for result_as_str",
+        ),
+        (
+            "{invalid json}",
+            Json,
+            NodeException,
+            WorkflowErrorCode.INVALID_OUTPUTS,
+            "Invalid JSON format for result_as_str",
+        ),
+        ('{"name": "Alice"}', List[str], ValueError, None, "Expected a list of items for result_as_str, received dict"),
+        ("data", object, ValueError, None, "Unsupported output type: <class 'object'>"),
+    ],
+    ids=[
+        "invalid_json_list",
+        "invalid_json_model",
+        "invalid_json_json_type",
+        "non_list_for_list",
+        "unsupported_type",
+    ],
+)
+def test_parse_type_from_str_error_cases(input_str, output_type, expected_exception, expected_code, expected_message):
+    with pytest.raises(expected_exception) as excinfo:
+        parse_type_from_str(input_str, output_type)
+
+    if expected_code:
+        assert excinfo.value.code == expected_code
+
+    assert expected_message in str(excinfo.value)
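The two Union[int, str] cases above rely on the parser trying each member of the union in order and returning the first successful parse. A simplified standalone sketch of that try-each-member pattern (the real helper dispatches each member through its own parsing rules; parse_union is a made-up name for illustration):

from typing import Union, get_args


def parse_union(value: str, union_type):
    # Try each member of the Union in declaration order; the first one that
    # parses cleanly wins, so "42" -> 42 and "hello" -> "hello" for Union[int, str].
    last_error = None
    for inner_type in get_args(union_type):
        try:
            return inner_type(value)
        except (TypeError, ValueError) as exc:
            last_error = exc
    raise ValueError(f"Could not parse {value!r} as {union_type}") from last_error


assert parse_union("42", Union[int, str]) == 42
assert parse_union("hello", Union[int, str]) == "hello"
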
vellum/workflows/nodes/utils.py
CHANGED
@@ -95,10 +95,18 @@ def parse_type_from_str(result_as_str: str, output_type: Any) -> Any:
         return bool(result_as_str)
 
     if get_origin(output_type) is list:
+        # Handle empty string case for list types by returning an empty list
+        if not result_as_str.strip():
+            return []
+
         try:
             data = json.loads(result_as_str)
         except json.JSONDecodeError:
-            raise ValueError("Invalid JSON Array format for result_as_str")
+            # raise ValueError("Invalid JSON Array format for result_as_str")
+            raise NodeException(
+                code=WorkflowErrorCode.INVALID_OUTPUTS,
+                message="Invalid JSON Array format for result_as_str",
+            )
 
         if not isinstance(data, list):
             raise ValueError(f"Expected a list of items for result_as_str, received {data.__class__.__name__}")
@@ -117,7 +125,10 @@ def parse_type_from_str(result_as_str: str, output_type: Any) -> Any:
                 return data["value"]
             return data
         except json.JSONDecodeError:
-            raise
+            raise NodeException(
+                code=WorkflowErrorCode.INVALID_OUTPUTS,
+                message="Invalid JSON format for result_as_str",
+            )
 
     if get_origin(output_type) is Union:
         for inner_type in get_args(output_type):
@@ -140,7 +151,10 @@ def parse_type_from_str(result_as_str: str, output_type: Any) -> Any:
                 data = data["value"]
             return output_type.model_validate(data)
         except json.JSONDecodeError:
-            raise
+            raise NodeException(
+                code=WorkflowErrorCode.INVALID_OUTPUTS,
+                message="Invalid JSON format for result_as_str",
+            )
 
     raise ValueError(f"Unsupported output type: {output_type}")
 
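The practical effect of these three hunks: callers of parse_type_from_str now get an empty list for an empty list-typed string, and malformed JSON is reported as a NodeException carrying WorkflowErrorCode.INVALID_OUTPUTS instead of a bare ValueError or a re-raised JSONDecodeError. A short sketch mirroring the behavior pinned down by the new tests:

from typing import List

from vellum.workflows.errors.types import WorkflowErrorCode
from vellum.workflows.exceptions import NodeException
from vellum.workflows.nodes.utils import parse_type_from_str

# Empty strings for list output types now short-circuit to an empty list.
assert parse_type_from_str("", List[str]) == []

# Malformed JSON now surfaces as a NodeException with INVALID_OUTPUTS.
try:
    parse_type_from_str("{invalid json}", List[str])
except NodeException as exc:
    assert exc.code == WorkflowErrorCode.INVALID_OUTPUTS
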
vellum/workflows/outputs/base.py
CHANGED
@@ -201,7 +201,7 @@ class BaseOutputs(metaclass=_BaseOutputsMeta):
         self._outputs_post_init(**kwargs)
 
     def __eq__(self, other: object) -> bool:
-        if not isinstance(other, dict):
+        if not isinstance(other, (dict, BaseOutputs)):
             return super().__eq__(other)
 
         outputs = {ref.name: value for ref, value in self if value is not undefined}
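Widening the isinstance check means comparing two Outputs instances now takes the structural path previously reserved for comparison against a plain dict, which is what lets the new mock tests assert equality between MockNodeExecution objects whose then_outputs were built separately. A small sketch of the behavior this enables, assuming a trivial node defined just for the example:

from vellum.workflows.nodes.bases.base import BaseNode


class ExampleNode(BaseNode):
    class Outputs(BaseNode.Outputs):
        foo: str


# Two separately constructed Outputs instances with the same values now
# compare equal via the dict-style comparison path.
assert ExampleNode.Outputs(foo="Hello") == ExampleNode.Outputs(foo="Hello")
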
vellum/workflows/runner/runner.py
CHANGED
@@ -200,7 +200,7 @@ class WorkflowRunner(Generic[StateType]):
                 parent=parent_context,
             )
             node_run_response: NodeRunResponse
-            was_mocked =
+            was_mocked: Optional[bool] = None
             mock_candidates = self.workflow.context.node_output_mocks_map.get(node.Outputs) or []
             for mock_candidate in mock_candidates:
                 if mock_candidate.when_condition.resolve(node.state):
@@ -315,6 +315,7 @@ class WorkflowRunner(Generic[StateType]):
                     node_definition=node.__class__,
                     outputs=outputs,
                     invoked_ports=invoked_ports,
+                    mocked=was_mocked,
                 ),
                 parent=parent_context,
            )
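The context lines show how a mock is chosen: the runner walks the candidates registered for the node's Outputs class and uses the first one whose when_condition resolves truthy against the node's state; the new was_mocked flag records that outcome and is forwarded as mocked=... on the fulfilled-node event. A standalone sketch of that selection pattern with stand-in objects (not the SDK's actual classes):

from dataclasses import dataclass
from typing import Any, Callable, List, Optional


@dataclass
class FakeMock:
    # Stand-in for a mock: a predicate over node state plus canned outputs.
    when_condition: Callable[[Any], bool]
    then_outputs: Any


def select_mock(candidates: List[FakeMock], state: Any) -> Optional[FakeMock]:
    # Mirrors the runner loop: the first candidate whose condition resolves
    # truthy wins; if none matches, the node runs for real.
    for candidate in candidates:
        if candidate.when_condition(state):
            return candidate
    return None


mocks = [FakeMock(lambda s: s["runs"] >= 1, {"text": "Hello"})]
assert select_mock(mocks, {"runs": 0}) is None
assert select_mock(mocks, {"runs": 2}).then_outputs == {"text": "Hello"}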