vellum-ai 0.14.8__py3-none-any.whl → 0.14.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/utils/templating/render.py +4 -1
  3. vellum/workflows/descriptors/base.py +6 -0
  4. vellum/workflows/descriptors/tests/test_utils.py +14 -0
  5. vellum/workflows/events/tests/test_event.py +40 -0
  6. vellum/workflows/events/workflow.py +21 -1
  7. vellum/workflows/expressions/greater_than.py +15 -8
  8. vellum/workflows/expressions/greater_than_or_equal_to.py +14 -8
  9. vellum/workflows/expressions/less_than.py +14 -8
  10. vellum/workflows/expressions/less_than_or_equal_to.py +14 -8
  11. vellum/workflows/expressions/parse_json.py +30 -0
  12. vellum/workflows/expressions/tests/__init__.py +0 -0
  13. vellum/workflows/expressions/tests/test_expressions.py +310 -0
  14. vellum/workflows/expressions/tests/test_parse_json.py +31 -0
  15. vellum/workflows/nodes/bases/base.py +5 -2
  16. vellum/workflows/nodes/core/templating_node/node.py +0 -1
  17. vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +49 -0
  18. vellum/workflows/nodes/displayable/api_node/tests/test_api_node.py +6 -7
  19. vellum/workflows/nodes/displayable/code_execution_node/node.py +18 -8
  20. vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +53 -0
  21. vellum/workflows/nodes/utils.py +14 -3
  22. vellum/workflows/runner/runner.py +34 -9
  23. vellum/workflows/state/encoder.py +2 -1
  24. {vellum_ai-0.14.8.dist-info → vellum_ai-0.14.10.dist-info}/METADATA +1 -1
  25. {vellum_ai-0.14.8.dist-info → vellum_ai-0.14.10.dist-info}/RECORD +35 -31
  26. vellum_ee/workflows/display/nodes/base_node_display.py +4 -0
  27. vellum_ee/workflows/display/tests/test_vellum_workflow_display.py +31 -0
  28. vellum_ee/workflows/display/types.py +1 -14
  29. vellum_ee/workflows/display/workflows/base_workflow_display.py +38 -18
  30. vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +27 -0
  31. vellum_ee/workflows/tests/test_display_meta.py +2 -0
  32. vellum_ee/workflows/tests/test_server.py +1 -0
  33. {vellum_ai-0.14.8.dist-info → vellum_ai-0.14.10.dist-info}/LICENSE +0 -0
  34. {vellum_ai-0.14.8.dist-info → vellum_ai-0.14.10.dist-info}/WHEEL +0 -0
  35. {vellum_ai-0.14.8.dist-info → vellum_ai-0.14.10.dist-info}/entry_points.txt +0 -0
vellum/workflows/expressions/tests/test_expressions.py
@@ -0,0 +1,310 @@
+ import pytest
+
+ from vellum.workflows.expressions.greater_than import GreaterThanExpression
+ from vellum.workflows.expressions.greater_than_or_equal_to import GreaterThanOrEqualToExpression
+ from vellum.workflows.expressions.less_than import LessThanExpression
+ from vellum.workflows.expressions.less_than_or_equal_to import LessThanOrEqualToExpression
+ from vellum.workflows.state.base import BaseState
+
+
+ class Comparable:
+     """A custom class with two values, where comparisons use a computed metric (multiplication)."""
+
+     def __init__(self, value1, value2):
+         self.value1 = value1  # First numerical value
+         self.value2 = value2  # Second numerical value
+
+     def computed_value(self):
+         return self.value1 * self.value2  # Multiply for comparison
+
+     def __ge__(self, other):
+         if isinstance(other, Comparable):
+             return self.computed_value() >= other.computed_value()
+         elif isinstance(other, (int, float)):
+             return self.computed_value() >= other
+         return NotImplemented
+
+     def __gt__(self, other):
+         if isinstance(other, Comparable):
+             return self.computed_value() > other.computed_value()
+         elif isinstance(other, (int, float)):
+             return self.computed_value() > other
+         return NotImplemented
+
+     def __le__(self, other):
+         if isinstance(other, Comparable):
+             return self.computed_value() <= other.computed_value()
+         elif isinstance(other, (int, float)):
+             return self.computed_value() <= other
+         return NotImplemented
+
+     def __lt__(self, other):
+         if isinstance(other, Comparable):
+             return self.computed_value() < other.computed_value()
+         elif isinstance(other, (int, float)):
+             return self.computed_value() < other
+         return NotImplemented
+
+
+ class NonComparable:
+     """A custom class that does not support comparisons."""
+
+     def __init__(self, value1, value2):
+         self.value1 = value1
+         self.value2 = value2
+
+
+ class TestState(BaseState):
+     pass
+
+
+ def test_greater_than_or_equal_to():
+     # GIVEN objects with two values
+     obj1 = Comparable(4, 5)  # Computed: 4 × 5 = 20
+     obj2 = Comparable(2, 10)  # Computed: 2 × 10 = 20
+     obj3 = Comparable(3, 6)  # Computed: 3 × 6 = 18
+     obj4 = Comparable(5, 5)  # Computed: 5 × 5 = 25
+
+     state = TestState()
+
+     # WHEN comparing objects
+     assert GreaterThanOrEqualToExpression(lhs=obj1, rhs=obj2).resolve(state) is True  # 20 >= 20
+     assert GreaterThanOrEqualToExpression(lhs=obj1, rhs=obj3).resolve(state) is True  # 20 >= 18
+     assert GreaterThanOrEqualToExpression(lhs=obj3, rhs=obj4).resolve(state) is False  # 18 < 25
+
+     # WHEN comparing to raw numbers
+     assert GreaterThanOrEqualToExpression(lhs=obj1, rhs=19).resolve(state) is True  # 20 >= 19
+     assert GreaterThanOrEqualToExpression(lhs=obj3, rhs=20).resolve(state) is False  # 18 < 20
+
+
+ def test_greater_than_or_equal_to_invalid():
+     # GIVEN objects with two values
+     obj1 = Comparable(4, 5)
+     obj2 = "invalid"
+
+     state = TestState()
+
+     # WHEN comparing objects with incompatible types
+     with pytest.raises(TypeError) as exc_info:
+         GreaterThanOrEqualToExpression(lhs=obj1, rhs=obj2).resolve(state)
+
+     # THEN the expected error is raised
+     assert str(exc_info.value) == "'>=' not supported between instances of 'Comparable' and 'str'"
+
+     # WHEN comparing objects with incompatible types
+     with pytest.raises(TypeError) as exc_info:
+         GreaterThanOrEqualToExpression(lhs=obj2, rhs=obj1).resolve(state)
+
+     # THEN the expected error is raised
+     assert str(exc_info.value) == "'>=' not supported between instances of 'str' and 'Comparable'"
+
+
+ def test_greater_than_or_equal_to_non_comparable():
+     # GIVEN objects with two values
+     obj1 = Comparable(4, 5)
+     obj2 = NonComparable(2, 10)
+
+     state = TestState()
+
+     # WHEN comparing objects with incompatible types
+     with pytest.raises(TypeError) as exc_info:
+         GreaterThanOrEqualToExpression(lhs=obj1, rhs=obj2).resolve(state)
+
+     # THEN the expected error is raised
+     assert str(exc_info.value) == "'>=' not supported between instances of 'Comparable' and 'NonComparable'"
+
+     # WHEN comparing objects with incompatible types
+     with pytest.raises(TypeError) as exc_info:
+         GreaterThanOrEqualToExpression(lhs=obj2, rhs=obj1).resolve(state)
+
+     # THEN the expected error is raised
+     assert str(exc_info.value) == "'>=' not supported between instances of 'NonComparable' and 'Comparable'"
+
+
+ def test_greater_than():
+     # GIVEN objects with two values
+     obj1 = Comparable(4, 5)  # Computed: 4 × 5 = 20
+     obj2 = Comparable(2, 10)  # Computed: 2 × 10 = 20
+     obj3 = Comparable(3, 6)  # Computed: 3 × 6 = 18
+     obj4 = Comparable(5, 5)  # Computed: 5 × 5 = 25
+
+     state = TestState()
+
+     # WHEN comparing objects
+     assert GreaterThanExpression(lhs=obj1, rhs=obj2).resolve(state) is False  # 20 > 20
+     assert GreaterThanExpression(lhs=obj1, rhs=obj3).resolve(state) is True  # 20 > 18
+     assert GreaterThanExpression(lhs=obj3, rhs=obj4).resolve(state) is False  # 18 < 25
+
+     # WHEN comparing to raw numbers
+     assert GreaterThanExpression(lhs=obj1, rhs=19).resolve(state) is True  # 20 > 19
+     assert GreaterThanExpression(lhs=obj3, rhs=20).resolve(state) is False  # 18 < 20
+
+
+ def test_greater_than_invalid():
+     # GIVEN objects with two values
+     obj1 = Comparable(4, 5)
+     obj2 = "invalid"
+
+     state = TestState()
+
+     # WHEN comparing objects with incompatible types
+     with pytest.raises(TypeError) as exc_info:
+         GreaterThanExpression(lhs=obj1, rhs=obj2).resolve(state)
+
+     # THEN the expected error is raised
+     assert str(exc_info.value) == "'>' not supported between instances of 'Comparable' and 'str'"
+
+     # WHEN comparing objects with incompatible types
+     with pytest.raises(TypeError) as exc_info:
+         GreaterThanExpression(lhs=obj2, rhs=obj1).resolve(state)
+
+     # THEN the expected error is raised
+     assert str(exc_info.value) == "'>' not supported between instances of 'str' and 'Comparable'"
+
+
+ def test_greater_than_non_comparable():
+     # GIVEN objects with two values
+     obj1 = Comparable(4, 5)
+     obj2 = NonComparable(2, 10)
+
+     state = TestState()
+
+     # WHEN comparing objects with incompatible types
+     with pytest.raises(TypeError) as exc_info:
+         GreaterThanExpression(lhs=obj1, rhs=obj2).resolve(state)
+
+     # THEN the expected error is raised
+     assert str(exc_info.value) == "'>' not supported between instances of 'Comparable' and 'NonComparable'"
+
+     # WHEN comparing objects with incompatible types
+     with pytest.raises(TypeError) as exc_info:
+         GreaterThanExpression(lhs=obj2, rhs=obj1).resolve(state)
+
+     # THEN the expected error is raised
+     assert str(exc_info.value) == "'>' not supported between instances of 'NonComparable' and 'Comparable'"
+
+
+ def test_less_than_or_equal_to():
+     # GIVEN objects with two values
+     obj1 = Comparable(4, 5)  # Computed: 4 × 5 = 20
+     obj2 = Comparable(2, 10)  # Computed: 2 × 10 = 20
+     obj3 = Comparable(3, 6)  # Computed: 3 × 6 = 18
+     obj4 = Comparable(5, 5)  # Computed: 5 × 5 = 25
+
+     state = TestState()
+
+     # WHEN comparing objects
+     assert LessThanOrEqualToExpression(lhs=obj1, rhs=obj2).resolve(state) is True  # 20 <= 20
+     assert LessThanOrEqualToExpression(lhs=obj1, rhs=obj3).resolve(state) is False  # 20 > 18
+     assert LessThanOrEqualToExpression(lhs=obj3, rhs=obj4).resolve(state) is True  # 18 <= 25
+
+     # WHEN comparing to raw numbers
+     assert LessThanOrEqualToExpression(lhs=obj1, rhs=21).resolve(state) is True  # 20 <= 21
+     assert LessThanOrEqualToExpression(lhs=obj3, rhs=17).resolve(state) is False  # 18 > 17
+
+
+ def test_less_than_or_equal_to_invalid():
+     # GIVEN objects with two values
+     obj1 = Comparable(4, 5)
+     obj2 = "invalid"
+
+     state = TestState()
+
+     # WHEN comparing objects with incompatible types
+     with pytest.raises(TypeError) as exc_info:
+         LessThanOrEqualToExpression(lhs=obj1, rhs=obj2).resolve(state)
+
+     # THEN the expected error is raised
+     assert str(exc_info.value) == "'<=' not supported between instances of 'Comparable' and 'str'"
+
+     # WHEN comparing objects with incompatible types
+     with pytest.raises(TypeError) as exc_info:
+         LessThanOrEqualToExpression(lhs=obj2, rhs=obj1).resolve(state)
+
+     # THEN the expected error is raised
+     assert str(exc_info.value) == "'<=' not supported between instances of 'str' and 'Comparable'"
+
+
+ def test_less_than_or_equal_to_non_comparable():
+     # GIVEN objects with two values
+     obj1 = Comparable(4, 5)
+     obj2 = NonComparable(2, 10)
+
+     state = TestState()
+
+     # WHEN comparing objects with incompatible types
+     with pytest.raises(TypeError) as exc_info:
+         LessThanOrEqualToExpression(lhs=obj1, rhs=obj2).resolve(state)
+
+     # THEN the expected error is raised
+     assert str(exc_info.value) == "'<=' not supported between instances of 'Comparable' and 'NonComparable'"
+
+     # WHEN comparing objects with incompatible types
+     with pytest.raises(TypeError) as exc_info:
+         LessThanOrEqualToExpression(lhs=obj2, rhs=obj1).resolve(state)
+
+     # THEN the expected error is raised
+     assert str(exc_info.value) == "'<=' not supported between instances of 'NonComparable' and 'Comparable'"
+
+
+ def test_less_than():
+     # GIVEN objects with two values
+     obj1 = Comparable(4, 5)  # Computed: 4 × 5 = 20
+     obj2 = Comparable(2, 10)  # Computed: 2 × 10 = 20
+     obj3 = Comparable(3, 6)  # Computed: 3 × 6 = 18
+     obj4 = Comparable(5, 5)  # Computed: 5 × 5 = 25
+
+     state = TestState()
+
+     # WHEN comparing objects
+     assert LessThanExpression(lhs=obj1, rhs=obj2).resolve(state) is False  # 20 < 20
+     assert LessThanExpression(lhs=obj1, rhs=obj3).resolve(state) is False  # 20 < 18
+     assert LessThanExpression(lhs=obj3, rhs=obj4).resolve(state) is True  # 18 < 25
+
+     # WHEN comparing to raw numbers
+     assert LessThanExpression(lhs=obj1, rhs=21).resolve(state) is True  # 20 < 21
+     assert LessThanExpression(lhs=obj3, rhs=17).resolve(state) is False  # 18 > 17
+
+
+ def test_less_than_invalid():
+     # GIVEN objects with two values
+     obj1 = Comparable(4, 5)
+     obj2 = "invalid"
+
+     state = TestState()
+
+     # WHEN comparing objects with incompatible types
+     with pytest.raises(TypeError) as exc_info:
+         LessThanExpression(lhs=obj1, rhs=obj2).resolve(state)
+
+     # THEN the expected error is raised
+     assert str(exc_info.value) == "'<' not supported between instances of 'Comparable' and 'str'"
+
+     # WHEN comparing objects with incompatible types
+     with pytest.raises(TypeError) as exc_info:
+         LessThanExpression(lhs=obj2, rhs=obj1).resolve(state)
+
+     # THEN the expected error is raised
+     assert str(exc_info.value) == "'<' not supported between instances of 'str' and 'Comparable'"
+
+
+ def test_less_than_non_comparable():
+     # GIVEN objects with two values
+     obj1 = Comparable(4, 5)
+     obj2 = NonComparable(2, 10)
+
+     state = TestState()
+
+     # WHEN comparing objects with incompatible types
+     with pytest.raises(TypeError) as exc_info:
+         LessThanExpression(lhs=obj1, rhs=obj2).resolve(state)
+
+     # THEN the expected error is raised
+     assert str(exc_info.value) == "'<' not supported between instances of 'Comparable' and 'NonComparable'"
+
+     # WHEN comparing objects with incompatible types
+     with pytest.raises(TypeError) as exc_info:
+         LessThanExpression(lhs=obj2, rhs=obj1).resolve(state)
+
+     # THEN the expected error is raised
+     assert str(exc_info.value) == "'<' not supported between instances of 'NonComparable' and 'Comparable'"
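Note: the error strings asserted above are CPython's own. When both the forward and reflected comparison methods return NotImplemented (or are absent), the bare operator raises exactly this TypeError. That suggests the reworked expression classes (greater_than.py and friends, +15/-8 in the file list) now delegate to the native operators instead of type-checking operands themselves; the expression sources are not shown in this diff, so that reading is inferred. The operator fallback itself is easy to confirm in isolation:

class Uncooperative:
    """Declines every comparison, like Comparable vs. str above."""
    def __gt__(self, other):
        return NotImplemented

try:
    Uncooperative() > "invalid"  # both sides decline, so Python raises
except TypeError as e:
    assert str(e) == "'>' not supported between instances of 'Uncooperative' and 'str'"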
vellum/workflows/expressions/tests/test_parse_json.py
@@ -0,0 +1,31 @@
+ import pytest
+
+ from vellum.workflows.descriptors.exceptions import InvalidExpressionException
+ from vellum.workflows.references.constant import ConstantValueReference
+ from vellum.workflows.state.base import BaseState
+
+
+ def test_parse_json_invalid_json():
+     # GIVEN an invalid JSON string
+     state = BaseState()
+     expression = ConstantValueReference('{"key": value}').parse_json()
+
+     # WHEN we attempt to resolve the expression
+     with pytest.raises(InvalidExpressionException) as exc_info:
+         expression.resolve(state)
+
+     # THEN an exception should be raised
+     assert "Failed to parse JSON" in str(exc_info.value)
+
+
+ def test_parse_json_invalid_type():
+     # GIVEN a non-string value
+     state = BaseState()
+     expression = ConstantValueReference(123).parse_json()
+
+     # WHEN we attempt to resolve the expression
+     with pytest.raises(InvalidExpressionException) as exc_info:
+         expression.resolve(state)
+
+     # THEN an exception should be raised
+     assert str(exc_info.value) == "Expected a string, but got 123 of type <class 'int'>"
vellum/workflows/nodes/bases/base.py
@@ -318,9 +318,12 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
          original_base = get_original_base(self.__class__)
  
          args = get_args(original_base)
-         state_type = args[0]
  
-         if isinstance(state_type, TypeVar):
+         if args and len(args) > 0:
+             state_type = args[0]
+             if isinstance(state_type, TypeVar):
+                 state_type = BaseState
+         else:
              state_type = BaseState
  
          self.state = state_type()
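The guard fixes an unguarded args[0] that would raise IndexError whenever get_args returned an empty tuple. A standalone illustration with plain typing of the three shapes the branch now handles (Base stands in for whatever get_original_base returns; this is not vellum code):

from typing import Generic, TypeVar, get_args

StateType = TypeVar("StateType")

class Base(Generic[StateType]):
    pass

assert get_args(Base[int]) == (int,)              # concrete parameter: usable state type
assert get_args(Base[StateType]) == (StateType,)  # bare TypeVar: fall back to BaseState
assert get_args(Base) == ()                       # unparametrized: args[0] would have raised IndexError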
vellum/workflows/nodes/core/templating_node/node.py
@@ -82,7 +82,6 @@ class TemplatingNode(BaseNode[StateType], Generic[StateType, _OutputType], metac
      def run(self) -> Outputs:
          rendered_template = self._render_template()
          result = self._cast_rendered_template(rendered_template)
-
          return self.Outputs(result=result)
  
      def _render_template(self) -> str:
vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py
@@ -234,3 +234,52 @@ def test_templating_node__last_chat_message():
  
      # THEN the output is the expected JSON
      assert outputs.result == [ChatMessage(role="USER", text="Hello")]
+
+
+ def test_templating_node__function_call_value_input():
+     # GIVEN a templating node that receives a FunctionCallVellumValue
+     class FunctionCallTemplateNode(TemplatingNode[BaseState, FunctionCall]):
+         template = """{{ function_call }}"""
+         inputs = {
+             "function_call": FunctionCallVellumValue(
+                 type="FUNCTION_CALL",
+                 value=FunctionCall(name="test_function", arguments={"key": "value"}, id="test_id", state="FULFILLED"),
+             )
+         }
+
+     # WHEN the node is run
+     node = FunctionCallTemplateNode()
+     outputs = node.run()
+
+     # THEN the output is the expected function call
+     assert outputs.result == FunctionCall(
+         name="test_function", arguments={"key": "value"}, id="test_id", state="FULFILLED"
+     )
+
+
+ def test_templating_node__function_call_as_json():
+     # GIVEN a node that receives a FunctionCallVellumValue but outputs as Json
+     class JsonOutputNode(TemplatingNode[BaseState, Json]):
+         template = """{{ function_call }}"""
+         inputs = {
+             "function_call": FunctionCallVellumValue(
+                 type="FUNCTION_CALL",
+                 value=FunctionCall(name="test_function", arguments={"key": "value"}, id="test_id", state="FULFILLED"),
+             )
+         }
+
+     # WHEN the node is run
+     node = JsonOutputNode()
+     outputs = node.run()
+
+     # THEN we get just the FunctionCall data as JSON
+     assert outputs.result == {
+         "name": "test_function",
+         "arguments": {"key": "value"},
+         "id": "test_id",
+         "state": "FULFILLED",
+     }
+
+     # AND we can access fields directly
+     assert outputs.result["arguments"] == {"key": "value"}
+     assert outputs.result["name"] == "test_function"
vellum/workflows/nodes/displayable/api_node/tests/test_api_node.py
@@ -5,7 +5,6 @@ from vellum.client.core.api_error import ApiError
  from vellum.workflows.constants import APIRequestMethod, AuthorizationType
  from vellum.workflows.exceptions import NodeException
  from vellum.workflows.nodes import APINode
- from vellum.workflows.state import BaseState
  from vellum.workflows.types.core import VellumSecret
  
  
@@ -29,7 +28,7 @@ def test_run_workflow__secrets(vellum_client):
          }
          bearer_token_value = VellumSecret(name="secret")
  
-     node = SimpleBaseAPINode(state=BaseState())
+     node = SimpleBaseAPINode()
      terminal = node.run()
  
      assert vellum_client.execute_api.call_count == 1
@@ -39,7 +38,7 @@ def test_run_workflow__secrets(vellum_client):
  
  
  def test_api_node_raises_error_when_api_call_fails(vellum_client):
-     # Mock the vellum_client to raise an ApiError
+     # GIVEN an API call that fails
      vellum_client.execute_api.side_effect = ApiError(status_code=400, body="API Error")
  
      class SimpleAPINode(APINode):
@@ -54,14 +53,14 @@ def test_api_node_raises_error_when_api_call_fails(vellum_client):
          }
          bearer_token_value = VellumSecret(name="api_key")
  
-     node = SimpleAPINode(state=BaseState())
+     node = SimpleAPINode()
  
-     # Assert that the NodeException is raised
+     # WHEN we run the node
      with pytest.raises(NodeException) as excinfo:
          node.run()
  
-     # Verify that the exception contains some error message
+     # THEN an exception should be raised
      assert "Failed to prepare HTTP request" in str(excinfo.value)
  
-     # Verify the vellum_client was called
+     # AND the API call should have been made
      assert vellum_client.execute_api.call_count == 1
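The dropped state=BaseState() arguments line up with the base.py change earlier in this diff: SimpleAPINode subclasses APINode without a state parameter, so node construction now falls back to a default state on its own. That connection is inferred from the two hunks, not stated anywhere in the release:

# Assumed post-change behavior, per the base.py hunk above:
node = SimpleAPINode()
assert isinstance(node.state, BaseState)  # default state, no explicit argument needed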
vellum/workflows/nodes/displayable/code_execution_node/node.py
@@ -19,6 +19,7 @@ from vellum import (
      VellumError,
      VellumValue,
  )
+ from vellum.client.core.api_error import ApiError
  from vellum.client.types.code_executor_secret_input import CodeExecutorSecretInput
  from vellum.core import RequestOptions
  from vellum.workflows.errors.types import WorkflowErrorCode
@@ -103,14 +104,23 @@ class CodeExecutionNode(BaseNode[StateType], Generic[StateType, _OutputType], me
          input_values = self._compile_code_inputs()
          expected_output_type = primitive_type_to_vellum_variable_type(output_type)
  
-         code_execution_result = self._context.vellum_client.execute_code(
-             input_values=input_values,
-             code=code,
-             runtime=self.runtime,
-             output_type=expected_output_type,
-             packages=self.packages or [],
-             request_options=self.request_options,
-         )
+         try:
+             code_execution_result = self._context.vellum_client.execute_code(
+                 input_values=input_values,
+                 code=code,
+                 runtime=self.runtime,
+                 output_type=expected_output_type,
+                 packages=self.packages or [],
+                 request_options=self.request_options,
+             )
+         except ApiError as e:
+             if e.status_code == 400 and isinstance(e.body, dict) and "message" in e.body:
+                 raise NodeException(
+                     message=e.body["message"],
+                     code=WorkflowErrorCode.INVALID_INPUTS,
+                 )
+
+             raise
  
          if code_execution_result.output.type != expected_output_type:
              actual_type = code_execution_result.output.type
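The shape of the handler: a 400 ApiError carrying a structured body becomes a NodeException with INVALID_INPUTS, and everything else re-raises untouched. A minimal self-contained sketch of the same control flow, with plain stand-ins for the vellum exception types:

class ApiError(Exception):  # stand-in for vellum.client.core.api_error.ApiError
    def __init__(self, status_code=None, body=None):
        self.status_code = status_code
        self.body = body

class NodeException(Exception):  # stand-in for vellum.workflows.exceptions.NodeException
    pass

def run_with_translation(execute):
    try:
        return execute()
    except ApiError as e:
        # Only a 400 with a dict body containing "message" is user-facing...
        if e.status_code == 400 and isinstance(e.body, dict) and "message" in e.body:
            raise NodeException(e.body["message"])
        raise  # ...anything else (500s, string bodies) propagates unchanged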
vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py
@@ -5,6 +5,7 @@ from typing import Any, List, Union
  from pydantic import BaseModel
  
  from vellum import CodeExecutorResponse, NumberVellumValue, StringInput, StringVellumValue
+ from vellum.client.errors.bad_request_error import BadRequestError
  from vellum.client.types.chat_message import ChatMessage
  from vellum.client.types.code_execution_package import CodeExecutionPackage
  from vellum.client.types.code_executor_secret_input import CodeExecutorSecretInput
@@ -17,6 +18,7 @@ from vellum.workflows.inputs.base import BaseInputs
  from vellum.workflows.nodes.displayable.code_execution_node import CodeExecutionNode
  from vellum.workflows.references.vellum_secret import VellumSecretReference
  from vellum.workflows.state.base import BaseState, StateMeta
+ from vellum.workflows.types.core import Json
  
  
  def test_run_node__happy_path(vellum_client):
@@ -690,3 +692,54 @@ def main():
          "result": [ChatMessage(role="USER", content=StringChatMessageContent(value="Hello, world!"))],
          "log": "",
      }
+
+
+ def test_run_node__execute_code_api_fails__node_exception(vellum_client):
+     # GIVEN a node that will throw a JSON.parse error
+     class ExampleCodeExecutionNode(CodeExecutionNode[BaseState, Json]):
+         code = """\
+ async function main(inputs: {
+   data: string,
+ }): Promise<string> {
+   return JSON.parse(inputs.data)
+ }
+ """
+         code_inputs = {
+             "data": "not a valid json string",
+         }
+         runtime = "TYPESCRIPT_5_3_3"
+
+     # AND the execute_code API will fail
+     message = """\
+ Code execution error (exit code 1): undefined:1
+ not a valid json string
+  ^
+
+ SyntaxError: Unexpected token 'o', \"not a valid\"... is not valid JSON
+     at JSON.parse (<anonymous>)
+     at Object.eval (eval at execute (/workdir/runner.js:16:18), <anonymous>:40:40)
+     at step (eval at execute (/workdir/runner.js:16:18), <anonymous>:32:23)
+     at Object.eval [as next] (eval at execute (/workdir/runner.js:16:18), <anonymous>:13:53)
+     at eval (eval at execute (/workdir/runner.js:16:18), <anonymous>:7:71)
+     at new Promise (<anonymous>)
+     at __awaiter (eval at execute (/workdir/runner.js:16:18), <anonymous>:3:12)
+     at Object.main (eval at execute (/workdir/runner.js:16:18), <anonymous>:38:12)
+     at execute (/workdir/runner.js:17:33)
+     at Interface.<anonymous> (/workdir/runner.js:58:5)
+
+ Node.js v21.7.3
+ """
+     vellum_client.execute_code.side_effect = BadRequestError(
+         body={
+             "message": message,
+             "log": "",
+         }
+     )
+
+     # WHEN we run the node
+     node = ExampleCodeExecutionNode()
+     with pytest.raises(NodeException) as exc_info:
+         node.run()
+
+     # AND the error should contain the execution error details
+     assert exc_info.value.message == message
vellum/workflows/nodes/utils.py
@@ -109,7 +109,11 @@ def parse_type_from_str(result_as_str: str, output_type: Any) -> Any:
  
      if output_type is Json:
          try:
-             return json.loads(result_as_str)
+             data = json.loads(result_as_str)
+             # If we got a FunctionCallVellumValue, return just the value
+             if isinstance(data, dict) and data.get("type") == "FUNCTION_CALL" and "value" in data:
+                 return data["value"]
+             return data
          except json.JSONDecodeError:
              raise ValueError("Invalid JSON format for result_as_str")
  
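A worked example of the new Json branch, using only the logic shown above:

import json

result_as_str = json.dumps({
    "type": "FUNCTION_CALL",
    "value": {"name": "test_function", "arguments": {"key": "value"}},
})

data = json.loads(result_as_str)
if isinstance(data, dict) and data.get("type") == "FUNCTION_CALL" and "value" in data:
    data = data["value"]  # strip the FunctionCallVellumValue envelope

assert data == {"name": "test_function", "arguments": {"key": "value"}}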
@@ -124,9 +128,16 @@ def parse_type_from_str(result_as_str: str, output_type: Any) -> Any:
      if issubclass(output_type, BaseModel):
          try:
              data = json.loads(result_as_str)
+             # If we got a FunctionCallVellumValue, extract the inner FunctionCall
+             if (
+                 hasattr(output_type, "__name__")
+                 and output_type.__name__ == "FunctionCall"
+                 and isinstance(data, dict)
+                 and "value" in data
+             ):
+                 data = data["value"]
+             return output_type.model_validate(data)
          except json.JSONDecodeError:
              raise ValueError("Invalid JSON format for result_as_str")
  
-     return output_type.model_validate(data)
-
      raise ValueError(f"Unsupported output type: {output_type}")