vellum-ai 0.14.7__py3-none-any.whl → 0.14.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/__init__.py +2 -0
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/types/__init__.py +2 -0
- vellum/client/types/document_prompt_block.py +29 -0
- vellum/client/types/prompt_block.py +2 -0
- vellum/types/document_prompt_block.py +3 -0
- vellum/workflows/descriptors/base.py +6 -0
- vellum/workflows/descriptors/tests/test_utils.py +14 -0
- vellum/workflows/events/tests/test_event.py +40 -0
- vellum/workflows/events/workflow.py +20 -1
- vellum/workflows/expressions/greater_than.py +15 -8
- vellum/workflows/expressions/greater_than_or_equal_to.py +14 -8
- vellum/workflows/expressions/less_than.py +14 -8
- vellum/workflows/expressions/less_than_or_equal_to.py +14 -8
- vellum/workflows/expressions/parse_json.py +30 -0
- vellum/workflows/expressions/tests/__init__.py +0 -0
- vellum/workflows/expressions/tests/test_expressions.py +310 -0
- vellum/workflows/expressions/tests/test_parse_json.py +31 -0
- vellum/workflows/nodes/bases/base.py +5 -2
- vellum/workflows/nodes/displayable/api_node/tests/test_api_node.py +34 -2
- vellum/workflows/nodes/displayable/bases/api_node/node.py +1 -1
- vellum/workflows/nodes/displayable/code_execution_node/node.py +18 -8
- vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +53 -0
- vellum/workflows/runner/runner.py +33 -4
- vellum/workflows/state/encoder.py +2 -1
- {vellum_ai-0.14.7.dist-info → vellum_ai-0.14.9.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.7.dist-info → vellum_ai-0.14.9.dist-info}/RECORD +44 -38
- vellum_cli/__init__.py +9 -2
- vellum_cli/config.py +1 -0
- vellum_cli/init.py +6 -2
- vellum_cli/pull.py +1 -0
- vellum_cli/tests/test_init.py +194 -76
- vellum_cli/tests/test_pull.py +8 -0
- vellum_cli/tests/test_push.py +1 -0
- vellum_ee/workflows/display/nodes/base_node_display.py +4 -0
- vellum_ee/workflows/display/tests/test_vellum_workflow_display.py +114 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +118 -3
- vellum_ee/workflows/display/types.py +1 -14
- vellum_ee/workflows/display/workflows/base_workflow_display.py +48 -19
- vellum_ee/workflows/display/workflows/vellum_workflow_display.py +12 -0
- vellum_ee/workflows/tests/test_server.py +1 -0
- {vellum_ai-0.14.7.dist-info → vellum_ai-0.14.9.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.7.dist-info → vellum_ai-0.14.9.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.7.dist-info → vellum_ai-0.14.9.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,310 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
from vellum.workflows.expressions.greater_than import GreaterThanExpression
|
4
|
+
from vellum.workflows.expressions.greater_than_or_equal_to import GreaterThanOrEqualToExpression
|
5
|
+
from vellum.workflows.expressions.less_than import LessThanExpression
|
6
|
+
from vellum.workflows.expressions.less_than_or_equal_to import LessThanOrEqualToExpression
|
7
|
+
from vellum.workflows.state.base import BaseState
|
8
|
+
|
9
|
+
|
10
|
+
class Comparable:
|
11
|
+
"""A custom class with two values, where comparisons use a computed metric (multiplication)."""
|
12
|
+
|
13
|
+
def __init__(self, value1, value2):
|
14
|
+
self.value1 = value1 # First numerical value
|
15
|
+
self.value2 = value2 # Second numerical value
|
16
|
+
|
17
|
+
def computed_value(self):
|
18
|
+
return self.value1 * self.value2 # Multiply for comparison
|
19
|
+
|
20
|
+
def __ge__(self, other):
|
21
|
+
if isinstance(other, Comparable):
|
22
|
+
return self.computed_value() >= other.computed_value()
|
23
|
+
elif isinstance(other, (int, float)):
|
24
|
+
return self.computed_value() >= other
|
25
|
+
return NotImplemented
|
26
|
+
|
27
|
+
def __gt__(self, other):
|
28
|
+
if isinstance(other, Comparable):
|
29
|
+
return self.computed_value() > other.computed_value()
|
30
|
+
elif isinstance(other, (int, float)):
|
31
|
+
return self.computed_value() > other
|
32
|
+
return NotImplemented
|
33
|
+
|
34
|
+
def __le__(self, other):
|
35
|
+
if isinstance(other, Comparable):
|
36
|
+
return self.computed_value() <= other.computed_value()
|
37
|
+
elif isinstance(other, (int, float)):
|
38
|
+
return self.computed_value() <= other
|
39
|
+
return NotImplemented
|
40
|
+
|
41
|
+
def __lt__(self, other):
|
42
|
+
if isinstance(other, Comparable):
|
43
|
+
return self.computed_value() < other.computed_value()
|
44
|
+
elif isinstance(other, (int, float)):
|
45
|
+
return self.computed_value() < other
|
46
|
+
return NotImplemented
|
47
|
+
|
48
|
+
|
49
|
+
class NonComparable:
|
50
|
+
"""A custom class that does not support comparisons."""
|
51
|
+
|
52
|
+
def __init__(self, value1, value2):
|
53
|
+
self.value1 = value1
|
54
|
+
self.value2 = value2
|
55
|
+
|
56
|
+
|
57
|
+
class TestState(BaseState):
|
58
|
+
pass
|
59
|
+
|
60
|
+
|
61
|
+
def test_greater_than_or_equal_to():
|
62
|
+
# GIVEN objects with two values
|
63
|
+
obj1 = Comparable(4, 5) # Computed: 4 × 5 = 20
|
64
|
+
obj2 = Comparable(2, 10) # Computed: 2 × 10 = 20
|
65
|
+
obj3 = Comparable(3, 6) # Computed: 3 × 6 = 18
|
66
|
+
obj4 = Comparable(5, 5) # Computed: 5 × 5 = 25
|
67
|
+
|
68
|
+
state = TestState()
|
69
|
+
|
70
|
+
# WHEN comparing objects
|
71
|
+
assert GreaterThanOrEqualToExpression(lhs=obj1, rhs=obj2).resolve(state) is True # 20 >= 20
|
72
|
+
assert GreaterThanOrEqualToExpression(lhs=obj1, rhs=obj3).resolve(state) is True # 20 >= 18
|
73
|
+
assert GreaterThanOrEqualToExpression(lhs=obj3, rhs=obj4).resolve(state) is False # 18 < 25
|
74
|
+
|
75
|
+
# WHEN comparing to raw numbers
|
76
|
+
assert GreaterThanOrEqualToExpression(lhs=obj1, rhs=19).resolve(state) is True # 20 >= 19
|
77
|
+
assert GreaterThanOrEqualToExpression(lhs=obj3, rhs=20).resolve(state) is False # 18 < 20
|
78
|
+
|
79
|
+
|
80
|
+
def test_greater_than_or_equal_to_invalid():
|
81
|
+
# GIVEN a comparable object and a string it cannot be compared with
|
82
|
+
obj1 = Comparable(4, 5)
|
83
|
+
obj2 = "invalid"
|
84
|
+
|
85
|
+
state = TestState()
|
86
|
+
|
87
|
+
# WHEN comparing objects with incompatible types
|
88
|
+
with pytest.raises(TypeError) as exc_info:
|
89
|
+
GreaterThanOrEqualToExpression(lhs=obj1, rhs=obj2).resolve(state)
|
90
|
+
|
91
|
+
# THEN the expected error is raised
|
92
|
+
assert str(exc_info.value) == "'>=' not supported between instances of 'Comparable' and 'str'"
|
93
|
+
|
94
|
+
# WHEN comparing objects with incompatible types
|
95
|
+
with pytest.raises(TypeError) as exc_info:
|
96
|
+
GreaterThanOrEqualToExpression(lhs=obj2, rhs=obj1).resolve(state)
|
97
|
+
|
98
|
+
# THEN the expected error is raised
|
99
|
+
assert str(exc_info.value) == "'>=' not supported between instances of 'str' and 'Comparable'"
|
100
|
+
|
101
|
+
|
102
|
+
def test_greater_than_or_equal_to_non_comparable():
|
103
|
+
# GIVEN a comparable object and an object that defines no comparison methods
|
104
|
+
obj1 = Comparable(4, 5)
|
105
|
+
obj2 = NonComparable(2, 10)
|
106
|
+
|
107
|
+
state = TestState()
|
108
|
+
|
109
|
+
# WHEN comparing objects with incompatible types
|
110
|
+
with pytest.raises(TypeError) as exc_info:
|
111
|
+
GreaterThanOrEqualToExpression(lhs=obj1, rhs=obj2).resolve(state)
|
112
|
+
|
113
|
+
# THEN the expected error is raised
|
114
|
+
assert str(exc_info.value) == "'>=' not supported between instances of 'Comparable' and 'NonComparable'"
|
115
|
+
|
116
|
+
# WHEN comparing objects with incompatible types
|
117
|
+
with pytest.raises(TypeError) as exc_info:
|
118
|
+
GreaterThanOrEqualToExpression(lhs=obj2, rhs=obj1).resolve(state)
|
119
|
+
|
120
|
+
# THEN the expected error is raised
|
121
|
+
assert str(exc_info.value) == "'>=' not supported between instances of 'NonComparable' and 'Comparable'"
|
122
|
+
|
123
|
+
|
124
|
+
def test_greater_than():
|
125
|
+
# GIVEN objects with two values
|
126
|
+
obj1 = Comparable(4, 5) # Computed: 4 × 5 = 20
|
127
|
+
obj2 = Comparable(2, 10) # Computed: 2 × 10 = 20
|
128
|
+
obj3 = Comparable(3, 6) # Computed: 3 × 6 = 18
|
129
|
+
obj4 = Comparable(5, 5) # Computed: 5 × 5 = 25
|
130
|
+
|
131
|
+
state = TestState()
|
132
|
+
|
133
|
+
# WHEN comparing objects
|
134
|
+
assert GreaterThanExpression(lhs=obj1, rhs=obj2).resolve(state) is False # 20 == 20, not greater
|
135
|
+
assert GreaterThanExpression(lhs=obj1, rhs=obj3).resolve(state) is True # 20 > 18
|
136
|
+
assert GreaterThanExpression(lhs=obj3, rhs=obj4).resolve(state) is False # 18 < 25
|
137
|
+
|
138
|
+
# WHEN comparing to raw numbers
|
139
|
+
assert GreaterThanExpression(lhs=obj1, rhs=19).resolve(state) is True # 20 > 19
|
140
|
+
assert GreaterThanExpression(lhs=obj3, rhs=20).resolve(state) is False # 18 < 20
|
141
|
+
|
142
|
+
|
143
|
+
def test_greater_than_invalid():
|
144
|
+
# GIVEN a comparable object and a string it cannot be compared with
|
145
|
+
obj1 = Comparable(4, 5)
|
146
|
+
obj2 = "invalid"
|
147
|
+
|
148
|
+
state = TestState()
|
149
|
+
|
150
|
+
# WHEN comparing objects with incompatible types
|
151
|
+
with pytest.raises(TypeError) as exc_info:
|
152
|
+
GreaterThanExpression(lhs=obj1, rhs=obj2).resolve(state)
|
153
|
+
|
154
|
+
# THEN the expected error is raised
|
155
|
+
assert str(exc_info.value) == "'>' not supported between instances of 'Comparable' and 'str'"
|
156
|
+
|
157
|
+
# WHEN comparing objects with incompatible types
|
158
|
+
with pytest.raises(TypeError) as exc_info:
|
159
|
+
GreaterThanExpression(lhs=obj2, rhs=obj1).resolve(state)
|
160
|
+
|
161
|
+
# THEN the expected error is raised
|
162
|
+
assert str(exc_info.value) == "'>' not supported between instances of 'str' and 'Comparable'"
|
163
|
+
|
164
|
+
|
165
|
+
def test_greater_than_non_comparable():
|
166
|
+
# GIVEN a comparable object and an object that defines no comparison methods
|
167
|
+
obj1 = Comparable(4, 5)
|
168
|
+
obj2 = NonComparable(2, 10)
|
169
|
+
|
170
|
+
state = TestState()
|
171
|
+
|
172
|
+
# WHEN comparing objects with incompatible types
|
173
|
+
with pytest.raises(TypeError) as exc_info:
|
174
|
+
GreaterThanExpression(lhs=obj1, rhs=obj2).resolve(state)
|
175
|
+
|
176
|
+
# THEN the expected error is raised
|
177
|
+
assert str(exc_info.value) == "'>' not supported between instances of 'Comparable' and 'NonComparable'"
|
178
|
+
|
179
|
+
# WHEN comparing objects with incompatible types
|
180
|
+
with pytest.raises(TypeError) as exc_info:
|
181
|
+
GreaterThanExpression(lhs=obj2, rhs=obj1).resolve(state)
|
182
|
+
|
183
|
+
# THEN the expected error is raised
|
184
|
+
assert str(exc_info.value) == "'>' not supported between instances of 'NonComparable' and 'Comparable'"
|
185
|
+
|
186
|
+
|
187
|
+
def test_less_than_or_equal_to():
|
188
|
+
# GIVEN objects with two values
|
189
|
+
obj1 = Comparable(4, 5) # Computed: 4 × 5 = 20
|
190
|
+
obj2 = Comparable(2, 10) # Computed: 2 × 10 = 20
|
191
|
+
obj3 = Comparable(3, 6) # Computed: 3 × 6 = 18
|
192
|
+
obj4 = Comparable(5, 5) # Computed: 5 × 5 = 25
|
193
|
+
|
194
|
+
state = TestState()
|
195
|
+
|
196
|
+
# WHEN comparing objects
|
197
|
+
assert LessThanOrEqualToExpression(lhs=obj1, rhs=obj2).resolve(state) is True # 20 <= 20
|
198
|
+
assert LessThanOrEqualToExpression(lhs=obj1, rhs=obj3).resolve(state) is False # 20 > 18
|
199
|
+
assert LessThanOrEqualToExpression(lhs=obj3, rhs=obj4).resolve(state) is True # 18 <= 25
|
200
|
+
|
201
|
+
# WHEN comparing to raw numbers
|
202
|
+
assert LessThanOrEqualToExpression(lhs=obj1, rhs=21).resolve(state) is True # 20 <= 21
|
203
|
+
assert LessThanOrEqualToExpression(lhs=obj3, rhs=17).resolve(state) is False # 18 > 17
|
204
|
+
|
205
|
+
|
206
|
+
def test_less_than_or_equal_to_invalid():
|
207
|
+
# GIVEN a comparable object and a string it cannot be compared with
|
208
|
+
obj1 = Comparable(4, 5)
|
209
|
+
obj2 = "invalid"
|
210
|
+
|
211
|
+
state = TestState()
|
212
|
+
|
213
|
+
# WHEN comparing objects with incompatible types
|
214
|
+
with pytest.raises(TypeError) as exc_info:
|
215
|
+
LessThanOrEqualToExpression(lhs=obj1, rhs=obj2).resolve(state)
|
216
|
+
|
217
|
+
# THEN the expected error is raised
|
218
|
+
assert str(exc_info.value) == "'<=' not supported between instances of 'Comparable' and 'str'"
|
219
|
+
|
220
|
+
# WHEN comparing objects with incompatible types
|
221
|
+
with pytest.raises(TypeError) as exc_info:
|
222
|
+
LessThanOrEqualToExpression(lhs=obj2, rhs=obj1).resolve(state)
|
223
|
+
|
224
|
+
# THEN the expected error is raised
|
225
|
+
assert str(exc_info.value) == "'<=' not supported between instances of 'str' and 'Comparable'"
|
226
|
+
|
227
|
+
|
228
|
+
def test_less_than_or_equal_to_non_comparable():
|
229
|
+
# GIVEN a comparable object and an object that defines no comparison methods
|
230
|
+
obj1 = Comparable(4, 5)
|
231
|
+
obj2 = NonComparable(2, 10)
|
232
|
+
|
233
|
+
state = TestState()
|
234
|
+
|
235
|
+
# WHEN comparing objects with incompatible types
|
236
|
+
with pytest.raises(TypeError) as exc_info:
|
237
|
+
LessThanOrEqualToExpression(lhs=obj1, rhs=obj2).resolve(state)
|
238
|
+
|
239
|
+
# THEN the expected error is raised
|
240
|
+
assert str(exc_info.value) == "'<=' not supported between instances of 'Comparable' and 'NonComparable'"
|
241
|
+
|
242
|
+
# WHEN comparing objects with incompatible types
|
243
|
+
with pytest.raises(TypeError) as exc_info:
|
244
|
+
LessThanOrEqualToExpression(lhs=obj2, rhs=obj1).resolve(state)
|
245
|
+
|
246
|
+
# THEN the expected error is raised
|
247
|
+
assert str(exc_info.value) == "'<=' not supported between instances of 'NonComparable' and 'Comparable'"
|
248
|
+
|
249
|
+
|
250
|
+
def test_less_than():
|
251
|
+
# GIVEN objects with two values
|
252
|
+
obj1 = Comparable(4, 5) # Computed: 4 × 5 = 20
|
253
|
+
obj2 = Comparable(2, 10) # Computed: 2 × 10 = 20
|
254
|
+
obj3 = Comparable(3, 6) # Computed: 3 × 6 = 18
|
255
|
+
obj4 = Comparable(5, 5) # Computed: 5 × 5 = 25
|
256
|
+
|
257
|
+
state = TestState()
|
258
|
+
|
259
|
+
# WHEN comparing objects
|
260
|
+
assert LessThanExpression(lhs=obj1, rhs=obj2).resolve(state) is False # 20 == 20, not less
|
261
|
+
assert LessThanExpression(lhs=obj1, rhs=obj3).resolve(state) is False # 20 > 18
|
262
|
+
assert LessThanExpression(lhs=obj3, rhs=obj4).resolve(state) is True # 18 < 25
|
263
|
+
|
264
|
+
# WHEN comparing to raw numbers
|
265
|
+
assert LessThanExpression(lhs=obj1, rhs=21).resolve(state) is True # 20 < 21
|
266
|
+
assert LessThanExpression(lhs=obj3, rhs=17).resolve(state) is False # 18 > 17
|
267
|
+
|
268
|
+
|
269
|
+
def test_less_than_invalid():
|
270
|
+
# GIVEN a comparable object and a string it cannot be compared with
|
271
|
+
obj1 = Comparable(4, 5)
|
272
|
+
obj2 = "invalid"
|
273
|
+
|
274
|
+
state = TestState()
|
275
|
+
|
276
|
+
# WHEN comparing objects with incompatible types
|
277
|
+
with pytest.raises(TypeError) as exc_info:
|
278
|
+
LessThanExpression(lhs=obj1, rhs=obj2).resolve(state)
|
279
|
+
|
280
|
+
# THEN the expected error is raised
|
281
|
+
assert str(exc_info.value) == "'<' not supported between instances of 'Comparable' and 'str'"
|
282
|
+
|
283
|
+
# WHEN comparing objects with incompatible types
|
284
|
+
with pytest.raises(TypeError) as exc_info:
|
285
|
+
LessThanExpression(lhs=obj2, rhs=obj1).resolve(state)
|
286
|
+
|
287
|
+
# THEN the expected error is raised
|
288
|
+
assert str(exc_info.value) == "'<' not supported between instances of 'str' and 'Comparable'"
|
289
|
+
|
290
|
+
|
291
|
+
def test_less_than_non_comparable():
|
292
|
+
# GIVEN a comparable object and an object that defines no comparison methods
|
293
|
+
obj1 = Comparable(4, 5)
|
294
|
+
obj2 = NonComparable(2, 10)
|
295
|
+
|
296
|
+
state = TestState()
|
297
|
+
|
298
|
+
# WHEN comparing objects with incompatible types
|
299
|
+
with pytest.raises(TypeError) as exc_info:
|
300
|
+
LessThanExpression(lhs=obj1, rhs=obj2).resolve(state)
|
301
|
+
|
302
|
+
# THEN the expected error is raised
|
303
|
+
assert str(exc_info.value) == "'<' not supported between instances of 'Comparable' and 'NonComparable'"
|
304
|
+
|
305
|
+
# WHEN comparing objects with incompatible types
|
306
|
+
with pytest.raises(TypeError) as exc_info:
|
307
|
+
LessThanExpression(lhs=obj2, rhs=obj1).resolve(state)
|
308
|
+
|
309
|
+
# THEN the expected error is raised
|
310
|
+
assert str(exc_info.value) == "'<' not supported between instances of 'NonComparable' and 'Comparable'"
|
@@ -0,0 +1,31 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
from vellum.workflows.descriptors.exceptions import InvalidExpressionException
|
4
|
+
from vellum.workflows.references.constant import ConstantValueReference
|
5
|
+
from vellum.workflows.state.base import BaseState
|
6
|
+
|
7
|
+
|
8
|
+
def test_parse_json_invalid_json():
|
9
|
+
# GIVEN an invalid JSON string
|
10
|
+
state = BaseState()
|
11
|
+
expression = ConstantValueReference('{"key": value}').parse_json()
|
12
|
+
|
13
|
+
# WHEN we attempt to resolve the expression
|
14
|
+
with pytest.raises(InvalidExpressionException) as exc_info:
|
15
|
+
expression.resolve(state)
|
16
|
+
|
17
|
+
# THEN an exception should be raised
|
18
|
+
assert "Failed to parse JSON" in str(exc_info.value)
|
19
|
+
|
20
|
+
|
21
|
+
def test_parse_json_invalid_type():
|
22
|
+
# GIVEN a non-string value
|
23
|
+
state = BaseState()
|
24
|
+
expression = ConstantValueReference(123).parse_json()
|
25
|
+
|
26
|
+
# WHEN we attempt to resolve the expression
|
27
|
+
with pytest.raises(InvalidExpressionException) as exc_info:
|
28
|
+
expression.resolve(state)
|
29
|
+
|
30
|
+
# THEN an exception should be raised
|
31
|
+
assert "Expected a string, but got 123 of type <class 'int'>" == str(exc_info.value)
|
@@ -318,9 +318,12 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
|
|
318
318
|
original_base = get_original_base(self.__class__)
|
319
319
|
|
320
320
|
args = get_args(original_base)
|
321
|
-
state_type = args[0]
|
322
321
|
|
323
|
-
if
|
322
|
+
if args and len(args) > 0:
|
323
|
+
state_type = args[0]
|
324
|
+
if isinstance(state_type, TypeVar):
|
325
|
+
state_type = BaseState
|
326
|
+
else:
|
324
327
|
state_type = BaseState
|
325
328
|
|
326
329
|
self.state = state_type()
|
@@ -1,7 +1,10 @@
|
|
1
|
+
import pytest
|
2
|
+
|
1
3
|
from vellum import ExecuteApiResponse, VellumSecret as ClientVellumSecret
|
4
|
+
from vellum.client.core.api_error import ApiError
|
2
5
|
from vellum.workflows.constants import APIRequestMethod, AuthorizationType
|
6
|
+
from vellum.workflows.exceptions import NodeException
|
3
7
|
from vellum.workflows.nodes import APINode
|
4
|
-
from vellum.workflows.state import BaseState
|
5
8
|
from vellum.workflows.types.core import VellumSecret
|
6
9
|
|
7
10
|
|
@@ -25,10 +28,39 @@ def test_run_workflow__secrets(vellum_client):
|
|
25
28
|
}
|
26
29
|
bearer_token_value = VellumSecret(name="secret")
|
27
30
|
|
28
|
-
node = SimpleBaseAPINode(
|
31
|
+
node = SimpleBaseAPINode()
|
29
32
|
terminal = node.run()
|
30
33
|
|
31
34
|
assert vellum_client.execute_api.call_count == 1
|
32
35
|
bearer_token = vellum_client.execute_api.call_args.kwargs["bearer_token"]
|
33
36
|
assert bearer_token == ClientVellumSecret(name="secret")
|
34
37
|
assert terminal.headers == {"X-Response-Header": "bar"}
|
38
|
+
|
39
|
+
|
40
|
+
def test_api_node_raises_error_when_api_call_fails(vellum_client):
|
41
|
+
# GIVEN an API call that fails
|
42
|
+
vellum_client.execute_api.side_effect = ApiError(status_code=400, body="API Error")
|
43
|
+
|
44
|
+
class SimpleAPINode(APINode):
|
45
|
+
method = APIRequestMethod.GET
|
46
|
+
authorization_type = AuthorizationType.BEARER_TOKEN
|
47
|
+
url = "https://api.vellum.ai"
|
48
|
+
body = {
|
49
|
+
"key": "value",
|
50
|
+
}
|
51
|
+
headers = {
|
52
|
+
"X-Test-Header": "foo",
|
53
|
+
}
|
54
|
+
bearer_token_value = VellumSecret(name="api_key")
|
55
|
+
|
56
|
+
node = SimpleAPINode()
|
57
|
+
|
58
|
+
# WHEN we run the node
|
59
|
+
with pytest.raises(NodeException) as excinfo:
|
60
|
+
node.run()
|
61
|
+
|
62
|
+
# THEN an exception should be raised
|
63
|
+
assert "Failed to prepare HTTP request" in str(excinfo.value)
|
64
|
+
|
65
|
+
# AND the API call should have been made
|
66
|
+
assert vellum_client.execute_api.call_count == 1
|
@@ -89,7 +89,7 @@ class BaseAPINode(BaseNode, Generic[StateType]):
|
|
89
89
|
url=url, method=method.value, body=data, headers=headers, bearer_token=client_vellum_secret
|
90
90
|
)
|
91
91
|
except ApiError as e:
|
92
|
-
NodeException(f"Failed to prepare HTTP request: {e}", code=WorkflowErrorCode.NODE_EXECUTION)
|
92
|
+
raise NodeException(f"Failed to prepare HTTP request: {e}", code=WorkflowErrorCode.NODE_EXECUTION)
|
93
93
|
|
94
94
|
return self.Outputs(
|
95
95
|
json=vellum_response.json_,
|
@@ -19,6 +19,7 @@ from vellum import (
|
|
19
19
|
VellumError,
|
20
20
|
VellumValue,
|
21
21
|
)
|
22
|
+
from vellum.client.core.api_error import ApiError
|
22
23
|
from vellum.client.types.code_executor_secret_input import CodeExecutorSecretInput
|
23
24
|
from vellum.core import RequestOptions
|
24
25
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
@@ -103,14 +104,23 @@ class CodeExecutionNode(BaseNode[StateType], Generic[StateType, _OutputType], me
|
|
103
104
|
input_values = self._compile_code_inputs()
|
104
105
|
expected_output_type = primitive_type_to_vellum_variable_type(output_type)
|
105
106
|
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
107
|
+
try:
|
108
|
+
code_execution_result = self._context.vellum_client.execute_code(
|
109
|
+
input_values=input_values,
|
110
|
+
code=code,
|
111
|
+
runtime=self.runtime,
|
112
|
+
output_type=expected_output_type,
|
113
|
+
packages=self.packages or [],
|
114
|
+
request_options=self.request_options,
|
115
|
+
)
|
116
|
+
except ApiError as e:
|
117
|
+
if e.status_code == 400 and isinstance(e.body, dict) and "message" in e.body:
|
118
|
+
raise NodeException(
|
119
|
+
message=e.body["message"],
|
120
|
+
code=WorkflowErrorCode.INVALID_INPUTS,
|
121
|
+
)
|
122
|
+
|
123
|
+
raise
|
114
124
|
|
115
125
|
if code_execution_result.output.type != expected_output_type:
|
116
126
|
actual_type = code_execution_result.output.type
|
@@ -5,6 +5,7 @@ from typing import Any, List, Union
|
|
5
5
|
from pydantic import BaseModel
|
6
6
|
|
7
7
|
from vellum import CodeExecutorResponse, NumberVellumValue, StringInput, StringVellumValue
|
8
|
+
from vellum.client.errors.bad_request_error import BadRequestError
|
8
9
|
from vellum.client.types.chat_message import ChatMessage
|
9
10
|
from vellum.client.types.code_execution_package import CodeExecutionPackage
|
10
11
|
from vellum.client.types.code_executor_secret_input import CodeExecutorSecretInput
|
@@ -17,6 +18,7 @@ from vellum.workflows.inputs.base import BaseInputs
|
|
17
18
|
from vellum.workflows.nodes.displayable.code_execution_node import CodeExecutionNode
|
18
19
|
from vellum.workflows.references.vellum_secret import VellumSecretReference
|
19
20
|
from vellum.workflows.state.base import BaseState, StateMeta
|
21
|
+
from vellum.workflows.types.core import Json
|
20
22
|
|
21
23
|
|
22
24
|
def test_run_node__happy_path(vellum_client):
|
@@ -690,3 +692,54 @@ def main():
|
|
690
692
|
"result": [ChatMessage(role="USER", content=StringChatMessageContent(value="Hello, world!"))],
|
691
693
|
"log": "",
|
692
694
|
}
|
695
|
+
|
696
|
+
|
697
|
+
def test_run_node__execute_code_api_fails__node_exception(vellum_client):
|
698
|
+
# GIVEN a node that will throw a JSON.parse error
|
699
|
+
class ExampleCodeExecutionNode(CodeExecutionNode[BaseState, Json]):
|
700
|
+
code = """\
|
701
|
+
async function main(inputs: {
|
702
|
+
data: string,
|
703
|
+
}): Promise<string> {
|
704
|
+
return JSON.parse(inputs.data)
|
705
|
+
}
|
706
|
+
"""
|
707
|
+
code_inputs = {
|
708
|
+
"data": "not a valid json string",
|
709
|
+
}
|
710
|
+
runtime = "TYPESCRIPT_5_3_3"
|
711
|
+
|
712
|
+
# AND the execute_code API will fail
|
713
|
+
message = """\
|
714
|
+
Code execution error (exit code 1): undefined:1
|
715
|
+
not a valid json string
|
716
|
+
^
|
717
|
+
|
718
|
+
SyntaxError: Unexpected token 'o', \"not a valid\"... is not valid JSON
|
719
|
+
at JSON.parse (<anonymous>)
|
720
|
+
at Object.eval (eval at execute (/workdir/runner.js:16:18), <anonymous>:40:40)
|
721
|
+
at step (eval at execute (/workdir/runner.js:16:18), <anonymous>:32:23)
|
722
|
+
at Object.eval [as next] (eval at execute (/workdir/runner.js:16:18), <anonymous>:13:53)
|
723
|
+
at eval (eval at execute (/workdir/runner.js:16:18), <anonymous>:7:71)
|
724
|
+
at new Promise (<anonymous>)
|
725
|
+
at __awaiter (eval at execute (/workdir/runner.js:16:18), <anonymous>:3:12)
|
726
|
+
at Object.main (eval at execute (/workdir/runner.js:16:18), <anonymous>:38:12)
|
727
|
+
at execute (/workdir/runner.js:17:33)
|
728
|
+
at Interface.<anonymous> (/workdir/runner.js:58:5)
|
729
|
+
|
730
|
+
Node.js v21.7.3
|
731
|
+
"""
|
732
|
+
vellum_client.execute_code.side_effect = BadRequestError(
|
733
|
+
body={
|
734
|
+
"message": message,
|
735
|
+
"log": "",
|
736
|
+
}
|
737
|
+
)
|
738
|
+
|
739
|
+
# WHEN we run the node
|
740
|
+
node = ExampleCodeExecutionNode()
|
741
|
+
with pytest.raises(NodeException) as exc_info:
|
742
|
+
node.run()
|
743
|
+
|
744
|
+
# THEN the error should contain the execution error details
|
745
|
+
assert exc_info.value.message == message
|
@@ -1,5 +1,6 @@
|
|
1
1
|
from collections import defaultdict
|
2
2
|
from copy import deepcopy
|
3
|
+
from dataclasses import dataclass
|
3
4
|
import logging
|
4
5
|
from queue import Empty, Queue
|
5
6
|
from threading import Event as ThreadingEvent, Thread
|
@@ -63,6 +64,12 @@ ExternalInputsArg = Dict[ExternalInputReference, Any]
|
|
63
64
|
BackgroundThreadItem = Union[BaseState, WorkflowEvent, None]
|
64
65
|
|
65
66
|
|
67
|
+
@dataclass
|
68
|
+
class ActiveNode(Generic[StateType]):
|
69
|
+
node: BaseNode[StateType]
|
70
|
+
was_outputs_streamed: bool = False
|
71
|
+
|
72
|
+
|
66
73
|
class WorkflowRunner(Generic[StateType]):
|
67
74
|
_entrypoints: Iterable[Type[BaseNode]]
|
68
75
|
|
@@ -136,7 +143,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
136
143
|
self._dependencies: Dict[Type[BaseNode], Set[Type[BaseNode]]] = defaultdict(set)
|
137
144
|
self._state_forks: Set[StateType] = {self._initial_state}
|
138
145
|
|
139
|
-
self._active_nodes_by_execution_id: Dict[UUID,
|
146
|
+
self._active_nodes_by_execution_id: Dict[UUID, ActiveNode[StateType]] = {}
|
140
147
|
self._cancel_signal = cancel_signal
|
141
148
|
self._execution_context = init_execution_context or get_execution_context()
|
142
149
|
self._parent_context = self._execution_context.parent_context
|
@@ -404,7 +411,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
404
411
|
current_parent = get_parent_context()
|
405
412
|
node = node_class(state=state, context=self.workflow.context)
|
406
413
|
state.meta.node_execution_cache.initiate_node_execution(node_class, node_span_id)
|
407
|
-
self._active_nodes_by_execution_id[node_span_id] = node
|
414
|
+
self._active_nodes_by_execution_id[node_span_id] = ActiveNode(node=node)
|
408
415
|
|
409
416
|
worker_thread = Thread(
|
410
417
|
target=self._context_run_work_item,
|
@@ -413,10 +420,11 @@ class WorkflowRunner(Generic[StateType]):
|
|
413
420
|
worker_thread.start()
|
414
421
|
|
415
422
|
def _handle_work_item_event(self, event: WorkflowEvent) -> Optional[WorkflowError]:
|
416
|
-
|
417
|
-
if not
|
423
|
+
active_node = self._active_nodes_by_execution_id.get(event.span_id)
|
424
|
+
if not active_node:
|
418
425
|
return None
|
419
426
|
|
427
|
+
node = active_node.node
|
420
428
|
if event.name == "node.execution.rejected":
|
421
429
|
self._active_nodes_by_execution_id.pop(event.span_id)
|
422
430
|
return event.error
|
@@ -431,6 +439,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
431
439
|
if node_output_descriptor.name != event.output.name:
|
432
440
|
continue
|
433
441
|
|
442
|
+
active_node.was_outputs_streamed = True
|
434
443
|
self._workflow_event_outer_queue.put(
|
435
444
|
self._stream_workflow_event(
|
436
445
|
BaseOutput(
|
@@ -447,6 +456,26 @@ class WorkflowRunner(Generic[StateType]):
|
|
447
456
|
|
448
457
|
if event.name == "node.execution.fulfilled":
|
449
458
|
self._active_nodes_by_execution_id.pop(event.span_id)
|
459
|
+
if not active_node.was_outputs_streamed:
|
460
|
+
for event_node_output_descriptor, node_output_value in event.outputs:
|
461
|
+
for workflow_output_descriptor in self.workflow.Outputs:
|
462
|
+
node_output_descriptor = workflow_output_descriptor.instance
|
463
|
+
if not isinstance(node_output_descriptor, OutputReference):
|
464
|
+
continue
|
465
|
+
if node_output_descriptor.outputs_class != event.node_definition.Outputs:
|
466
|
+
continue
|
467
|
+
if node_output_descriptor.name != event_node_output_descriptor.name:
|
468
|
+
continue
|
469
|
+
|
470
|
+
self._workflow_event_outer_queue.put(
|
471
|
+
self._stream_workflow_event(
|
472
|
+
BaseOutput(
|
473
|
+
name=workflow_output_descriptor.name,
|
474
|
+
value=node_output_value,
|
475
|
+
)
|
476
|
+
)
|
477
|
+
)
|
478
|
+
|
450
479
|
self._handle_invoked_ports(node.state, event.invoked_ports)
|
451
480
|
|
452
481
|
return None
|
@@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, Type
|
|
8
8
|
|
9
9
|
from pydantic import BaseModel
|
10
10
|
|
11
|
+
from vellum.workflows.constants import undefined
|
11
12
|
from vellum.workflows.inputs.base import BaseInputs
|
12
13
|
from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
|
13
14
|
from vellum.workflows.ports.port import Port
|
@@ -22,7 +23,7 @@ class DefaultStateEncoder(JSONEncoder):
|
|
22
23
|
return dict(obj)
|
23
24
|
|
24
25
|
if isinstance(obj, (BaseInputs, BaseOutputs)):
|
25
|
-
return {descriptor.name: value for descriptor, value in obj}
|
26
|
+
return {descriptor.name: value for descriptor, value in obj if value is not undefined}
|
26
27
|
|
27
28
|
if isinstance(obj, (BaseOutput, Port)):
|
28
29
|
return obj.serialize()
|