vellum-ai 0.14.8__py3-none-any.whl → 0.14.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/utils/templating/render.py +4 -1
- vellum/workflows/descriptors/base.py +6 -0
- vellum/workflows/descriptors/tests/test_utils.py +14 -0
- vellum/workflows/events/tests/test_event.py +40 -0
- vellum/workflows/events/workflow.py +21 -1
- vellum/workflows/expressions/greater_than.py +15 -8
- vellum/workflows/expressions/greater_than_or_equal_to.py +14 -8
- vellum/workflows/expressions/less_than.py +14 -8
- vellum/workflows/expressions/less_than_or_equal_to.py +14 -8
- vellum/workflows/expressions/parse_json.py +30 -0
- vellum/workflows/expressions/tests/__init__.py +0 -0
- vellum/workflows/expressions/tests/test_expressions.py +310 -0
- vellum/workflows/expressions/tests/test_parse_json.py +31 -0
- vellum/workflows/nodes/bases/base.py +5 -2
- vellum/workflows/nodes/core/templating_node/node.py +0 -1
- vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +49 -0
- vellum/workflows/nodes/displayable/api_node/tests/test_api_node.py +6 -7
- vellum/workflows/nodes/displayable/code_execution_node/node.py +18 -8
- vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +53 -0
- vellum/workflows/nodes/utils.py +14 -3
- vellum/workflows/runner/runner.py +34 -9
- vellum/workflows/state/encoder.py +2 -1
- {vellum_ai-0.14.8.dist-info → vellum_ai-0.14.10.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.8.dist-info → vellum_ai-0.14.10.dist-info}/RECORD +35 -31
- vellum_ee/workflows/display/nodes/base_node_display.py +4 -0
- vellum_ee/workflows/display/tests/test_vellum_workflow_display.py +31 -0
- vellum_ee/workflows/display/types.py +1 -14
- vellum_ee/workflows/display/workflows/base_workflow_display.py +38 -18
- vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +27 -0
- vellum_ee/workflows/tests/test_display_meta.py +2 -0
- vellum_ee/workflows/tests/test_server.py +1 -0
- {vellum_ai-0.14.8.dist-info → vellum_ai-0.14.10.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.8.dist-info → vellum_ai-0.14.10.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.8.dist-info → vellum_ai-0.14.10.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,310 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
from vellum.workflows.expressions.greater_than import GreaterThanExpression
|
4
|
+
from vellum.workflows.expressions.greater_than_or_equal_to import GreaterThanOrEqualToExpression
|
5
|
+
from vellum.workflows.expressions.less_than import LessThanExpression
|
6
|
+
from vellum.workflows.expressions.less_than_or_equal_to import LessThanOrEqualToExpression
|
7
|
+
from vellum.workflows.state.base import BaseState
|
8
|
+
|
9
|
+
|
10
|
+
class Comparable:
    """A custom class with two values, where comparisons use a computed metric (multiplication).

    Instances compare against other ``Comparable`` instances and against plain ``int``/``float``
    values using the product ``value1 * value2``. For any other operand type the comparison dunders
    return ``NotImplemented``, so Python raises its standard
    "'<op>' not supported between instances of ..." ``TypeError``.

    Equality and hashing follow the same metric, so ``a <= b and a >= b`` implies ``a == b``.
    """

    def __init__(self, value1, value2):
        self.value1 = value1  # First numerical value
        self.value2 = value2  # Second numerical value

    def __repr__(self):
        # Makes pytest assertion failures show the operands instead of a bare object address.
        return f"{type(self).__name__}({self.value1!r}, {self.value2!r})"

    def computed_value(self):
        return self.value1 * self.value2  # Multiply for comparison

    def _metric_of(self, other):
        """Return the value ``self`` should be compared against, or ``NotImplemented`` if unsupported."""
        if isinstance(other, Comparable):
            return other.computed_value()
        if isinstance(other, (int, float)):
            return other
        return NotImplemented

    def __eq__(self, other):
        other_value = self._metric_of(other)
        if other_value is NotImplemented:
            return NotImplemented
        return self.computed_value() == other_value

    def __hash__(self):
        # Must agree with __eq__: equal metrics (including an equal plain number) hash equally.
        return hash(self.computed_value())

    def __ge__(self, other):
        other_value = self._metric_of(other)
        if other_value is NotImplemented:
            return NotImplemented
        return self.computed_value() >= other_value

    def __gt__(self, other):
        other_value = self._metric_of(other)
        if other_value is NotImplemented:
            return NotImplemented
        return self.computed_value() > other_value

    def __le__(self, other):
        other_value = self._metric_of(other)
        if other_value is NotImplemented:
            return NotImplemented
        return self.computed_value() <= other_value

    def __lt__(self, other):
        other_value = self._metric_of(other)
        if other_value is NotImplemented:
            return NotImplemented
        return self.computed_value() < other_value
|
47
|
+
|
48
|
+
|
49
|
+
class NonComparable:
    """Holds two values but defines no ordering methods, so ordering comparisons with it raise TypeError."""

    def __init__(self, value1, value2):
        self.value1, self.value2 = value1, value2
|
55
|
+
|
56
|
+
|
57
|
+
class TestState(BaseState):
    # Plain state fixture, not a test suite. Its ``Test`` name prefix matches pytest's default
    # class-discovery pattern, so pytest would otherwise try to collect it (and warn if it has an
    # ``__init__``, which BaseState presumably provides). ``__test__ = False`` opts it out.
    __test__ = False
|
59
|
+
|
60
|
+
|
61
|
+
def test_greater_than_or_equal_to():
    # GIVEN objects whose computed values (value1 * value2) are 20, 20, 18 and 25
    twenty_a = Comparable(4, 5)
    twenty_b = Comparable(2, 10)
    eighteen = Comparable(3, 6)
    twenty_five = Comparable(5, 5)

    state = TestState()

    # WHEN resolving >= against other objects and against raw numbers
    # THEN each result matches the comparison of the computed values
    expectations = [
        (twenty_a, twenty_b, True),  # 20 >= 20
        (twenty_a, eighteen, True),  # 20 >= 18
        (eighteen, twenty_five, False),  # 18 < 25
        (twenty_a, 19, True),  # 20 >= 19
        (eighteen, 20, False),  # 18 < 20
    ]
    for lhs, rhs, expected in expectations:
        assert GreaterThanOrEqualToExpression(lhs=lhs, rhs=rhs).resolve(state) is expected
|
78
|
+
|
79
|
+
|
80
|
+
def test_greater_than_or_equal_to_invalid():
    # GIVEN a Comparable and a plain string it has no ordering with
    comparable = Comparable(4, 5)
    text = "invalid"

    state = TestState()

    scenarios = [
        (comparable, text, "'>=' not supported between instances of 'Comparable' and 'str'"),
        (text, comparable, "'>=' not supported between instances of 'str' and 'Comparable'"),
    ]
    for lhs, rhs, expected_message in scenarios:
        # WHEN resolving the comparison with the operands in either order
        with pytest.raises(TypeError) as exc_info:
            GreaterThanOrEqualToExpression(lhs=lhs, rhs=rhs).resolve(state)

        # THEN Python's standard unsupported-comparison error is raised
        assert str(exc_info.value) == expected_message
|
100
|
+
|
101
|
+
|
102
|
+
def test_greater_than_or_equal_to_non_comparable():
    # GIVEN a Comparable and an object that defines no ordering at all
    comparable = Comparable(4, 5)
    opaque = NonComparable(2, 10)

    state = TestState()

    scenarios = [
        (comparable, opaque, "'>=' not supported between instances of 'Comparable' and 'NonComparable'"),
        (opaque, comparable, "'>=' not supported between instances of 'NonComparable' and 'Comparable'"),
    ]
    for lhs, rhs, expected_message in scenarios:
        # WHEN resolving the comparison with the operands in either order
        with pytest.raises(TypeError) as exc_info:
            GreaterThanOrEqualToExpression(lhs=lhs, rhs=rhs).resolve(state)

        # THEN Python's standard unsupported-comparison error is raised
        assert str(exc_info.value) == expected_message
|
122
|
+
|
123
|
+
|
124
|
+
def test_greater_than():
    # GIVEN objects whose computed values (value1 * value2) are 20, 20, 18 and 25
    twenty_a = Comparable(4, 5)
    twenty_b = Comparable(2, 10)
    eighteen = Comparable(3, 6)
    twenty_five = Comparable(5, 5)

    state = TestState()

    # WHEN resolving > against other objects and against raw numbers
    # THEN each result matches the strict comparison of the computed values
    expectations = [
        (twenty_a, twenty_b, False),  # 20 == 20, not greater
        (twenty_a, eighteen, True),  # 20 > 18
        (eighteen, twenty_five, False),  # 18 < 25
        (twenty_a, 19, True),  # 20 > 19
        (eighteen, 20, False),  # 18 < 20
    ]
    for lhs, rhs, expected in expectations:
        assert GreaterThanExpression(lhs=lhs, rhs=rhs).resolve(state) is expected
|
141
|
+
|
142
|
+
|
143
|
+
def test_greater_than_invalid():
    # GIVEN a Comparable and a plain string it has no ordering with
    comparable = Comparable(4, 5)
    text = "invalid"

    state = TestState()

    scenarios = [
        (comparable, text, "'>' not supported between instances of 'Comparable' and 'str'"),
        (text, comparable, "'>' not supported between instances of 'str' and 'Comparable'"),
    ]
    for lhs, rhs, expected_message in scenarios:
        # WHEN resolving the comparison with the operands in either order
        with pytest.raises(TypeError) as exc_info:
            GreaterThanExpression(lhs=lhs, rhs=rhs).resolve(state)

        # THEN Python's standard unsupported-comparison error is raised
        assert str(exc_info.value) == expected_message
|
163
|
+
|
164
|
+
|
165
|
+
def test_greater_than_non_comparable():
    # GIVEN a Comparable and an object that defines no ordering at all
    comparable = Comparable(4, 5)
    opaque = NonComparable(2, 10)

    state = TestState()

    scenarios = [
        (comparable, opaque, "'>' not supported between instances of 'Comparable' and 'NonComparable'"),
        (opaque, comparable, "'>' not supported between instances of 'NonComparable' and 'Comparable'"),
    ]
    for lhs, rhs, expected_message in scenarios:
        # WHEN resolving the comparison with the operands in either order
        with pytest.raises(TypeError) as exc_info:
            GreaterThanExpression(lhs=lhs, rhs=rhs).resolve(state)

        # THEN Python's standard unsupported-comparison error is raised
        assert str(exc_info.value) == expected_message
|
185
|
+
|
186
|
+
|
187
|
+
def test_less_than_or_equal_to():
    # GIVEN objects whose computed values (value1 * value2) are 20, 20, 18 and 25
    twenty_a = Comparable(4, 5)
    twenty_b = Comparable(2, 10)
    eighteen = Comparable(3, 6)
    twenty_five = Comparable(5, 5)

    state = TestState()

    # WHEN resolving <= against other objects and against raw numbers
    # THEN each result matches the comparison of the computed values
    expectations = [
        (twenty_a, twenty_b, True),  # 20 <= 20
        (twenty_a, eighteen, False),  # 20 > 18
        (eighteen, twenty_five, True),  # 18 <= 25
        (twenty_a, 21, True),  # 20 <= 21
        (eighteen, 17, False),  # 18 > 17
    ]
    for lhs, rhs, expected in expectations:
        assert LessThanOrEqualToExpression(lhs=lhs, rhs=rhs).resolve(state) is expected
|
204
|
+
|
205
|
+
|
206
|
+
def test_less_than_or_equal_to_invalid():
    # GIVEN a Comparable and a plain string it has no ordering with
    comparable = Comparable(4, 5)
    text = "invalid"

    state = TestState()

    scenarios = [
        (comparable, text, "'<=' not supported between instances of 'Comparable' and 'str'"),
        (text, comparable, "'<=' not supported between instances of 'str' and 'Comparable'"),
    ]
    for lhs, rhs, expected_message in scenarios:
        # WHEN resolving the comparison with the operands in either order
        with pytest.raises(TypeError) as exc_info:
            LessThanOrEqualToExpression(lhs=lhs, rhs=rhs).resolve(state)

        # THEN Python's standard unsupported-comparison error is raised
        assert str(exc_info.value) == expected_message
|
226
|
+
|
227
|
+
|
228
|
+
def test_less_than_or_equal_to_non_comparable():
    # GIVEN a Comparable and an object that defines no ordering at all
    comparable = Comparable(4, 5)
    opaque = NonComparable(2, 10)

    state = TestState()

    scenarios = [
        (comparable, opaque, "'<=' not supported between instances of 'Comparable' and 'NonComparable'"),
        (opaque, comparable, "'<=' not supported between instances of 'NonComparable' and 'Comparable'"),
    ]
    for lhs, rhs, expected_message in scenarios:
        # WHEN resolving the comparison with the operands in either order
        with pytest.raises(TypeError) as exc_info:
            LessThanOrEqualToExpression(lhs=lhs, rhs=rhs).resolve(state)

        # THEN Python's standard unsupported-comparison error is raised
        assert str(exc_info.value) == expected_message
|
248
|
+
|
249
|
+
|
250
|
+
def test_less_than():
    # GIVEN objects whose computed values (value1 * value2) are 20, 20, 18 and 25
    twenty_a = Comparable(4, 5)
    twenty_b = Comparable(2, 10)
    eighteen = Comparable(3, 6)
    twenty_five = Comparable(5, 5)

    state = TestState()

    # WHEN resolving < against other objects and against raw numbers
    # THEN each result matches the strict comparison of the computed values
    expectations = [
        (twenty_a, twenty_b, False),  # 20 == 20, not less
        (twenty_a, eighteen, False),  # 20 > 18
        (eighteen, twenty_five, True),  # 18 < 25
        (twenty_a, 21, True),  # 20 < 21
        (eighteen, 17, False),  # 18 > 17
    ]
    for lhs, rhs, expected in expectations:
        assert LessThanExpression(lhs=lhs, rhs=rhs).resolve(state) is expected
|
267
|
+
|
268
|
+
|
269
|
+
def test_less_than_invalid():
    # GIVEN a Comparable and a plain string it has no ordering with
    comparable = Comparable(4, 5)
    text = "invalid"

    state = TestState()

    scenarios = [
        (comparable, text, "'<' not supported between instances of 'Comparable' and 'str'"),
        (text, comparable, "'<' not supported between instances of 'str' and 'Comparable'"),
    ]
    for lhs, rhs, expected_message in scenarios:
        # WHEN resolving the comparison with the operands in either order
        with pytest.raises(TypeError) as exc_info:
            LessThanExpression(lhs=lhs, rhs=rhs).resolve(state)

        # THEN Python's standard unsupported-comparison error is raised
        assert str(exc_info.value) == expected_message
|
289
|
+
|
290
|
+
|
291
|
+
def test_less_than_non_comparable():
    # GIVEN a Comparable and an object that defines no ordering at all
    comparable = Comparable(4, 5)
    opaque = NonComparable(2, 10)

    state = TestState()

    scenarios = [
        (comparable, opaque, "'<' not supported between instances of 'Comparable' and 'NonComparable'"),
        (opaque, comparable, "'<' not supported between instances of 'NonComparable' and 'Comparable'"),
    ]
    for lhs, rhs, expected_message in scenarios:
        # WHEN resolving the comparison with the operands in either order
        with pytest.raises(TypeError) as exc_info:
            LessThanExpression(lhs=lhs, rhs=rhs).resolve(state)

        # THEN Python's standard unsupported-comparison error is raised
        assert str(exc_info.value) == expected_message
|
@@ -0,0 +1,31 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
from vellum.workflows.descriptors.exceptions import InvalidExpressionException
|
4
|
+
from vellum.workflows.references.constant import ConstantValueReference
|
5
|
+
from vellum.workflows.state.base import BaseState
|
6
|
+
|
7
|
+
|
8
|
+
def test_parse_json_invalid_json():
    # GIVEN a constant reference whose string value is malformed JSON
    state = BaseState()
    malformed = ConstantValueReference('{"key": value}')
    expression = malformed.parse_json()

    # WHEN the parse_json expression is resolved
    with pytest.raises(InvalidExpressionException) as raised:
        expression.resolve(state)

    # THEN the failure is reported as a JSON parse error
    assert "Failed to parse JSON" in str(raised.value)
|
19
|
+
|
20
|
+
|
21
|
+
def test_parse_json_invalid_type():
    # GIVEN a constant reference holding an int instead of a string
    state = BaseState()
    not_a_string = ConstantValueReference(123)
    expression = not_a_string.parse_json()

    # WHEN the parse_json expression is resolved
    with pytest.raises(InvalidExpressionException) as raised:
        expression.resolve(state)

    # THEN the error names the offending value and its type
    assert str(raised.value) == "Expected a string, but got 123 of type <class 'int'>"
|
@@ -318,9 +318,12 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
|
|
318
318
|
original_base = get_original_base(self.__class__)
|
319
319
|
|
320
320
|
args = get_args(original_base)
|
321
|
-
state_type = args[0]
|
322
321
|
|
323
|
-
if
|
322
|
+
if args and len(args) > 0:
|
323
|
+
state_type = args[0]
|
324
|
+
if isinstance(state_type, TypeVar):
|
325
|
+
state_type = BaseState
|
326
|
+
else:
|
324
327
|
state_type = BaseState
|
325
328
|
|
326
329
|
self.state = state_type()
|
@@ -82,7 +82,6 @@ class TemplatingNode(BaseNode[StateType], Generic[StateType, _OutputType], metac
|
|
82
82
|
def run(self) -> Outputs:
|
83
83
|
rendered_template = self._render_template()
|
84
84
|
result = self._cast_rendered_template(rendered_template)
|
85
|
-
|
86
85
|
return self.Outputs(result=result)
|
87
86
|
|
88
87
|
def _render_template(self) -> str:
|
@@ -234,3 +234,52 @@ def test_templating_node__last_chat_message():
|
|
234
234
|
|
235
235
|
# THEN the output is the expected JSON
|
236
236
|
assert outputs.result == [ChatMessage(role="USER", text="Hello")]
|
237
|
+
|
238
|
+
|
239
|
+
def test_templating_node__function_call_value_input():
    # GIVEN a factory for the function call the template receives (fresh instance per call)
    def build_function_call() -> FunctionCall:
        return FunctionCall(name="test_function", arguments={"key": "value"}, id="test_id", state="FULFILLED")

    # AND a templating node that renders a FunctionCallVellumValue and casts back to FunctionCall
    class FunctionCallTemplateNode(TemplatingNode[BaseState, FunctionCall]):
        template = """{{ function_call }}"""
        inputs = {
            "function_call": FunctionCallVellumValue(type="FUNCTION_CALL", value=build_function_call()),
        }

    # WHEN the node is run
    outputs = FunctionCallTemplateNode().run()

    # THEN the rendered template is parsed back into an equivalent FunctionCall
    assert outputs.result == build_function_call()
|
258
|
+
|
259
|
+
|
260
|
+
def test_templating_node__function_call_as_json():
    # GIVEN a node whose FunctionCallVellumValue input is rendered and cast to Json
    class JsonOutputNode(TemplatingNode[BaseState, Json]):
        template = """{{ function_call }}"""
        inputs = {
            "function_call": FunctionCallVellumValue(
                type="FUNCTION_CALL",
                value=FunctionCall(name="test_function", arguments={"key": "value"}, id="test_id", state="FULFILLED"),
            )
        }

    # WHEN the node is run
    result = JsonOutputNode().run().result

    # THEN only the inner FunctionCall payload comes back, as plain JSON data
    expected = {
        "name": "test_function",
        "arguments": {"key": "value"},
        "id": "test_id",
        "state": "FULFILLED",
    }
    assert result == expected

    # AND individual fields are directly addressable
    assert result["arguments"] == {"key": "value"}
    assert result["name"] == "test_function"
|
@@ -5,7 +5,6 @@ from vellum.client.core.api_error import ApiError
|
|
5
5
|
from vellum.workflows.constants import APIRequestMethod, AuthorizationType
|
6
6
|
from vellum.workflows.exceptions import NodeException
|
7
7
|
from vellum.workflows.nodes import APINode
|
8
|
-
from vellum.workflows.state import BaseState
|
9
8
|
from vellum.workflows.types.core import VellumSecret
|
10
9
|
|
11
10
|
|
@@ -29,7 +28,7 @@ def test_run_workflow__secrets(vellum_client):
|
|
29
28
|
}
|
30
29
|
bearer_token_value = VellumSecret(name="secret")
|
31
30
|
|
32
|
-
node = SimpleBaseAPINode(
|
31
|
+
node = SimpleBaseAPINode()
|
33
32
|
terminal = node.run()
|
34
33
|
|
35
34
|
assert vellum_client.execute_api.call_count == 1
|
@@ -39,7 +38,7 @@ def test_run_workflow__secrets(vellum_client):
|
|
39
38
|
|
40
39
|
|
41
40
|
def test_api_node_raises_error_when_api_call_fails(vellum_client):
|
42
|
-
#
|
41
|
+
# GIVEN an API call that fails
|
43
42
|
vellum_client.execute_api.side_effect = ApiError(status_code=400, body="API Error")
|
44
43
|
|
45
44
|
class SimpleAPINode(APINode):
|
@@ -54,14 +53,14 @@ def test_api_node_raises_error_when_api_call_fails(vellum_client):
|
|
54
53
|
}
|
55
54
|
bearer_token_value = VellumSecret(name="api_key")
|
56
55
|
|
57
|
-
node = SimpleAPINode(
|
56
|
+
node = SimpleAPINode()
|
58
57
|
|
59
|
-
#
|
58
|
+
# WHEN we run the node
|
60
59
|
with pytest.raises(NodeException) as excinfo:
|
61
60
|
node.run()
|
62
61
|
|
63
|
-
#
|
62
|
+
# THEN an exception should be raised
|
64
63
|
assert "Failed to prepare HTTP request" in str(excinfo.value)
|
65
64
|
|
66
|
-
#
|
65
|
+
# AND the API call should have been made
|
67
66
|
assert vellum_client.execute_api.call_count == 1
|
@@ -19,6 +19,7 @@ from vellum import (
|
|
19
19
|
VellumError,
|
20
20
|
VellumValue,
|
21
21
|
)
|
22
|
+
from vellum.client.core.api_error import ApiError
|
22
23
|
from vellum.client.types.code_executor_secret_input import CodeExecutorSecretInput
|
23
24
|
from vellum.core import RequestOptions
|
24
25
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
@@ -103,14 +104,23 @@ class CodeExecutionNode(BaseNode[StateType], Generic[StateType, _OutputType], me
|
|
103
104
|
input_values = self._compile_code_inputs()
|
104
105
|
expected_output_type = primitive_type_to_vellum_variable_type(output_type)
|
105
106
|
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
107
|
+
try:
|
108
|
+
code_execution_result = self._context.vellum_client.execute_code(
|
109
|
+
input_values=input_values,
|
110
|
+
code=code,
|
111
|
+
runtime=self.runtime,
|
112
|
+
output_type=expected_output_type,
|
113
|
+
packages=self.packages or [],
|
114
|
+
request_options=self.request_options,
|
115
|
+
)
|
116
|
+
except ApiError as e:
|
117
|
+
if e.status_code == 400 and isinstance(e.body, dict) and "message" in e.body:
|
118
|
+
raise NodeException(
|
119
|
+
message=e.body["message"],
|
120
|
+
code=WorkflowErrorCode.INVALID_INPUTS,
|
121
|
+
)
|
122
|
+
|
123
|
+
raise
|
114
124
|
|
115
125
|
if code_execution_result.output.type != expected_output_type:
|
116
126
|
actual_type = code_execution_result.output.type
|
@@ -5,6 +5,7 @@ from typing import Any, List, Union
|
|
5
5
|
from pydantic import BaseModel
|
6
6
|
|
7
7
|
from vellum import CodeExecutorResponse, NumberVellumValue, StringInput, StringVellumValue
|
8
|
+
from vellum.client.errors.bad_request_error import BadRequestError
|
8
9
|
from vellum.client.types.chat_message import ChatMessage
|
9
10
|
from vellum.client.types.code_execution_package import CodeExecutionPackage
|
10
11
|
from vellum.client.types.code_executor_secret_input import CodeExecutorSecretInput
|
@@ -17,6 +18,7 @@ from vellum.workflows.inputs.base import BaseInputs
|
|
17
18
|
from vellum.workflows.nodes.displayable.code_execution_node import CodeExecutionNode
|
18
19
|
from vellum.workflows.references.vellum_secret import VellumSecretReference
|
19
20
|
from vellum.workflows.state.base import BaseState, StateMeta
|
21
|
+
from vellum.workflows.types.core import Json
|
20
22
|
|
21
23
|
|
22
24
|
def test_run_node__happy_path(vellum_client):
|
@@ -690,3 +692,54 @@ def main():
|
|
690
692
|
"result": [ChatMessage(role="USER", content=StringChatMessageContent(value="Hello, world!"))],
|
691
693
|
"log": "",
|
692
694
|
}
|
695
|
+
|
696
|
+
|
697
|
+
def test_run_node__execute_code_api_fails__node_exception(vellum_client):
    """A 400 from execute_code whose body has a ``message`` surfaces as an INVALID_INPUTS NodeException."""
    # Local import so this test does not depend on the module-level import list.
    from vellum.workflows.errors.types import WorkflowErrorCode

    # GIVEN a node that will throw a JSON.parse error
    class ExampleCodeExecutionNode(CodeExecutionNode[BaseState, Json]):
        code = """\
async function main(inputs: {
data: string,
}): Promise<string> {
return JSON.parse(inputs.data)
}
"""
        code_inputs = {
            "data": "not a valid json string",
        }
        runtime = "TYPESCRIPT_5_3_3"

    # AND the execute_code API will fail
    message = """\
Code execution error (exit code 1): undefined:1
not a valid json string
^

SyntaxError: Unexpected token 'o', \"not a valid\"... is not valid JSON
at JSON.parse (<anonymous>)
at Object.eval (eval at execute (/workdir/runner.js:16:18), <anonymous>:40:40)
at step (eval at execute (/workdir/runner.js:16:18), <anonymous>:32:23)
at Object.eval [as next] (eval at execute (/workdir/runner.js:16:18), <anonymous>:13:53)
at eval (eval at execute (/workdir/runner.js:16:18), <anonymous>:7:71)
at new Promise (<anonymous>)
at __awaiter (eval at execute (/workdir/runner.js:16:18), <anonymous>:3:12)
at Object.main (eval at execute (/workdir/runner.js:16:18), <anonymous>:38:12)
at execute (/workdir/runner.js:17:33)
at Interface.<anonymous> (/workdir/runner.js:58:5)

Node.js v21.7.3
"""
    vellum_client.execute_code.side_effect = BadRequestError(
        body={
            "message": message,
            "log": "",
        }
    )

    # WHEN we run the node
    node = ExampleCodeExecutionNode()
    with pytest.raises(NodeException) as exc_info:
        node.run()

    # THEN the error carries the API's execution error message verbatim
    assert exc_info.value.message == message

    # AND it is classified as invalid inputs, the code CodeExecutionNode uses for 400 responses
    assert exc_info.value.code == WorkflowErrorCode.INVALID_INPUTS
|
vellum/workflows/nodes/utils.py
CHANGED
@@ -109,7 +109,11 @@ def parse_type_from_str(result_as_str: str, output_type: Any) -> Any:
|
|
109
109
|
|
110
110
|
if output_type is Json:
|
111
111
|
try:
|
112
|
-
|
112
|
+
data = json.loads(result_as_str)
|
113
|
+
# If we got a FunctionCallVellumValue, return just the value
|
114
|
+
if isinstance(data, dict) and data.get("type") == "FUNCTION_CALL" and "value" in data:
|
115
|
+
return data["value"]
|
116
|
+
return data
|
113
117
|
except json.JSONDecodeError:
|
114
118
|
raise ValueError("Invalid JSON format for result_as_str")
|
115
119
|
|
@@ -124,9 +128,16 @@ def parse_type_from_str(result_as_str: str, output_type: Any) -> Any:
|
|
124
128
|
if issubclass(output_type, BaseModel):
|
125
129
|
try:
|
126
130
|
data = json.loads(result_as_str)
|
131
|
+
# If we got a FunctionCallVellumValue extract FunctionCall,
|
132
|
+
if (
|
133
|
+
hasattr(output_type, "__name__")
|
134
|
+
and output_type.__name__ == "FunctionCall"
|
135
|
+
and isinstance(data, dict)
|
136
|
+
and "value" in data
|
137
|
+
):
|
138
|
+
data = data["value"]
|
139
|
+
return output_type.model_validate(data)
|
127
140
|
except json.JSONDecodeError:
|
128
141
|
raise ValueError("Invalid JSON format for result_as_str")
|
129
142
|
|
130
|
-
return output_type.model_validate(data)
|
131
|
-
|
132
143
|
raise ValueError(f"Unsupported output type: {output_type}")
|