vellum-ai 0.14.7__py3-none-any.whl → 0.14.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. vellum/__init__.py +2 -0
  2. vellum/client/core/client_wrapper.py +1 -1
  3. vellum/client/types/__init__.py +2 -0
  4. vellum/client/types/document_prompt_block.py +29 -0
  5. vellum/client/types/prompt_block.py +2 -0
  6. vellum/types/document_prompt_block.py +3 -0
  7. vellum/workflows/descriptors/base.py +6 -0
  8. vellum/workflows/descriptors/tests/test_utils.py +14 -0
  9. vellum/workflows/events/tests/test_event.py +40 -0
  10. vellum/workflows/events/workflow.py +20 -1
  11. vellum/workflows/expressions/greater_than.py +15 -8
  12. vellum/workflows/expressions/greater_than_or_equal_to.py +14 -8
  13. vellum/workflows/expressions/less_than.py +14 -8
  14. vellum/workflows/expressions/less_than_or_equal_to.py +14 -8
  15. vellum/workflows/expressions/parse_json.py +30 -0
  16. vellum/workflows/expressions/tests/__init__.py +0 -0
  17. vellum/workflows/expressions/tests/test_expressions.py +310 -0
  18. vellum/workflows/expressions/tests/test_parse_json.py +31 -0
  19. vellum/workflows/nodes/bases/base.py +5 -2
  20. vellum/workflows/nodes/displayable/api_node/tests/test_api_node.py +34 -2
  21. vellum/workflows/nodes/displayable/bases/api_node/node.py +1 -1
  22. vellum/workflows/nodes/displayable/code_execution_node/node.py +18 -8
  23. vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +53 -0
  24. vellum/workflows/runner/runner.py +33 -4
  25. vellum/workflows/state/encoder.py +2 -1
  26. {vellum_ai-0.14.7.dist-info → vellum_ai-0.14.9.dist-info}/METADATA +1 -1
  27. {vellum_ai-0.14.7.dist-info → vellum_ai-0.14.9.dist-info}/RECORD +44 -38
  28. vellum_cli/__init__.py +9 -2
  29. vellum_cli/config.py +1 -0
  30. vellum_cli/init.py +6 -2
  31. vellum_cli/pull.py +1 -0
  32. vellum_cli/tests/test_init.py +194 -76
  33. vellum_cli/tests/test_pull.py +8 -0
  34. vellum_cli/tests/test_push.py +1 -0
  35. vellum_ee/workflows/display/nodes/base_node_display.py +4 -0
  36. vellum_ee/workflows/display/tests/test_vellum_workflow_display.py +114 -0
  37. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +118 -3
  38. vellum_ee/workflows/display/types.py +1 -14
  39. vellum_ee/workflows/display/workflows/base_workflow_display.py +48 -19
  40. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +12 -0
  41. vellum_ee/workflows/tests/test_server.py +1 -0
  42. {vellum_ai-0.14.7.dist-info → vellum_ai-0.14.9.dist-info}/LICENSE +0 -0
  43. {vellum_ai-0.14.7.dist-info → vellum_ai-0.14.9.dist-info}/WHEEL +0 -0
  44. {vellum_ai-0.14.7.dist-info → vellum_ai-0.14.9.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,310 @@
1
+ import pytest
2
+
3
+ from vellum.workflows.expressions.greater_than import GreaterThanExpression
4
+ from vellum.workflows.expressions.greater_than_or_equal_to import GreaterThanOrEqualToExpression
5
+ from vellum.workflows.expressions.less_than import LessThanExpression
6
+ from vellum.workflows.expressions.less_than_or_equal_to import LessThanOrEqualToExpression
7
+ from vellum.workflows.state.base import BaseState
8
+
9
+
10
class Comparable:
    """A custom class with two values, where comparisons use a computed metric (multiplication)."""

    def __init__(self, value1, value2):
        # The two operands whose product defines this object's ordering.
        self.value1 = value1
        self.value2 = value2

    def computed_value(self):
        # Single scalar that every rich comparison is based on.
        return self.value1 * self.value2

    def _comparison_key(self, other):
        """Return the metric to compare *other* against, or NotImplemented for unsupported types."""
        if isinstance(other, Comparable):
            return other.computed_value()
        if isinstance(other, (int, float)):
            return other
        return NotImplemented

    def __ge__(self, other):
        key = self._comparison_key(other)
        return key if key is NotImplemented else self.computed_value() >= key

    def __gt__(self, other):
        key = self._comparison_key(other)
        return key if key is NotImplemented else self.computed_value() > key

    def __le__(self, other):
        key = self._comparison_key(other)
        return key if key is NotImplemented else self.computed_value() <= key

    def __lt__(self, other):
        key = self._comparison_key(other)
        return key if key is NotImplemented else self.computed_value() < key
47
+
48
+
49
class NonComparable:
    """A custom class that does not support comparisons."""

    def __init__(self, value1, value2):
        # Plain attribute holders; no rich-comparison methods are defined on
        # purpose, so ordering operators against this type raise TypeError.
        self.value1 = value1
        self.value2 = value2
55
+
56
+
57
class TestState(BaseState):
    # Minimal concrete state; the expressions under test only need *some*
    # state instance to pass to resolve().
    pass
59
+
60
+
61
def test_greater_than_or_equal_to():
    # GIVEN comparables whose ordering metric is value1 * value2
    twenty_a = Comparable(4, 5)  # metric: 20
    twenty_b = Comparable(2, 10)  # metric: 20
    eighteen = Comparable(3, 6)  # metric: 18
    twenty_five = Comparable(5, 5)  # metric: 25

    state = TestState()

    # WHEN comparing object against object
    assert GreaterThanOrEqualToExpression(lhs=twenty_a, rhs=twenty_b).resolve(state) is True  # 20 >= 20
    assert GreaterThanOrEqualToExpression(lhs=twenty_a, rhs=eighteen).resolve(state) is True  # 20 >= 18
    assert GreaterThanOrEqualToExpression(lhs=eighteen, rhs=twenty_five).resolve(state) is False  # 18 < 25

    # WHEN comparing object against raw numbers
    assert GreaterThanOrEqualToExpression(lhs=twenty_a, rhs=19).resolve(state) is True  # 20 >= 19
    assert GreaterThanOrEqualToExpression(lhs=eighteen, rhs=20).resolve(state) is False  # 18 < 20
78
+
79
+
80
def test_greater_than_or_equal_to_invalid():
    # GIVEN a comparable object and a plain string
    comparable = Comparable(4, 5)
    not_comparable = "invalid"

    state = TestState()

    # WHEN the string is on the right-hand side of >=
    with pytest.raises(TypeError) as exc_info:
        GreaterThanOrEqualToExpression(lhs=comparable, rhs=not_comparable).resolve(state)

    # THEN Python's standard unsupported-comparison error surfaces
    assert str(exc_info.value) == "'>=' not supported between instances of 'Comparable' and 'str'"

    # WHEN the string is on the left-hand side of >=
    with pytest.raises(TypeError) as exc_info:
        GreaterThanOrEqualToExpression(lhs=not_comparable, rhs=comparable).resolve(state)

    # THEN the mirrored error message is produced
    assert str(exc_info.value) == "'>=' not supported between instances of 'str' and 'Comparable'"
100
+
101
+
102
def test_greater_than_or_equal_to_non_comparable():
    # GIVEN a comparable object and one with no comparison support at all
    comparable = Comparable(4, 5)
    non_comparable = NonComparable(2, 10)

    state = TestState()

    # WHEN the non-comparable is on the right-hand side of >=
    with pytest.raises(TypeError) as exc_info:
        GreaterThanOrEqualToExpression(lhs=comparable, rhs=non_comparable).resolve(state)

    # THEN the standard unsupported-comparison error surfaces
    assert str(exc_info.value) == "'>=' not supported between instances of 'Comparable' and 'NonComparable'"

    # WHEN the non-comparable is on the left-hand side of >=
    with pytest.raises(TypeError) as exc_info:
        GreaterThanOrEqualToExpression(lhs=non_comparable, rhs=comparable).resolve(state)

    # THEN the mirrored error message is produced
    assert str(exc_info.value) == "'>=' not supported between instances of 'NonComparable' and 'Comparable'"
122
+
123
+
124
def test_greater_than():
    # GIVEN comparables whose ordering metric is value1 * value2
    twenty_a = Comparable(4, 5)  # metric: 20
    twenty_b = Comparable(2, 10)  # metric: 20
    eighteen = Comparable(3, 6)  # metric: 18
    twenty_five = Comparable(5, 5)  # metric: 25

    state = TestState()

    # WHEN comparing object against object
    assert GreaterThanExpression(lhs=twenty_a, rhs=twenty_b).resolve(state) is False  # 20 is not > 20
    assert GreaterThanExpression(lhs=twenty_a, rhs=eighteen).resolve(state) is True  # 20 > 18
    assert GreaterThanExpression(lhs=eighteen, rhs=twenty_five).resolve(state) is False  # 18 < 25

    # WHEN comparing object against raw numbers
    assert GreaterThanExpression(lhs=twenty_a, rhs=19).resolve(state) is True  # 20 > 19
    assert GreaterThanExpression(lhs=eighteen, rhs=20).resolve(state) is False  # 18 < 20
141
+
142
+
143
def test_greater_than_invalid():
    # GIVEN a comparable object and a plain string
    comparable = Comparable(4, 5)
    not_comparable = "invalid"

    state = TestState()

    # WHEN the string is on the right-hand side of >
    with pytest.raises(TypeError) as exc_info:
        GreaterThanExpression(lhs=comparable, rhs=not_comparable).resolve(state)

    # THEN Python's standard unsupported-comparison error surfaces
    assert str(exc_info.value) == "'>' not supported between instances of 'Comparable' and 'str'"

    # WHEN the string is on the left-hand side of >
    with pytest.raises(TypeError) as exc_info:
        GreaterThanExpression(lhs=not_comparable, rhs=comparable).resolve(state)

    # THEN the mirrored error message is produced
    assert str(exc_info.value) == "'>' not supported between instances of 'str' and 'Comparable'"
163
+
164
+
165
def test_greater_than_non_comparable():
    # GIVEN a comparable object and one with no comparison support at all
    comparable = Comparable(4, 5)
    non_comparable = NonComparable(2, 10)

    state = TestState()

    # WHEN the non-comparable is on the right-hand side of >
    with pytest.raises(TypeError) as exc_info:
        GreaterThanExpression(lhs=comparable, rhs=non_comparable).resolve(state)

    # THEN the standard unsupported-comparison error surfaces
    assert str(exc_info.value) == "'>' not supported between instances of 'Comparable' and 'NonComparable'"

    # WHEN the non-comparable is on the left-hand side of >
    with pytest.raises(TypeError) as exc_info:
        GreaterThanExpression(lhs=non_comparable, rhs=comparable).resolve(state)

    # THEN the mirrored error message is produced
    assert str(exc_info.value) == "'>' not supported between instances of 'NonComparable' and 'Comparable'"
185
+
186
+
187
def test_less_than_or_equal_to():
    # GIVEN comparables whose ordering metric is value1 * value2
    twenty_a = Comparable(4, 5)  # metric: 20
    twenty_b = Comparable(2, 10)  # metric: 20
    eighteen = Comparable(3, 6)  # metric: 18
    twenty_five = Comparable(5, 5)  # metric: 25

    state = TestState()

    # WHEN comparing object against object
    assert LessThanOrEqualToExpression(lhs=twenty_a, rhs=twenty_b).resolve(state) is True  # 20 <= 20
    assert LessThanOrEqualToExpression(lhs=twenty_a, rhs=eighteen).resolve(state) is False  # 20 > 18
    assert LessThanOrEqualToExpression(lhs=eighteen, rhs=twenty_five).resolve(state) is True  # 18 <= 25

    # WHEN comparing object against raw numbers
    assert LessThanOrEqualToExpression(lhs=twenty_a, rhs=21).resolve(state) is True  # 20 <= 21
    assert LessThanOrEqualToExpression(lhs=eighteen, rhs=17).resolve(state) is False  # 18 > 17
204
+
205
+
206
def test_less_than_or_equal_to_invalid():
    # GIVEN a comparable object and a plain string
    comparable = Comparable(4, 5)
    not_comparable = "invalid"

    state = TestState()

    # WHEN the string is on the right-hand side of <=
    with pytest.raises(TypeError) as exc_info:
        LessThanOrEqualToExpression(lhs=comparable, rhs=not_comparable).resolve(state)

    # THEN Python's standard unsupported-comparison error surfaces
    assert str(exc_info.value) == "'<=' not supported between instances of 'Comparable' and 'str'"

    # WHEN the string is on the left-hand side of <=
    with pytest.raises(TypeError) as exc_info:
        LessThanOrEqualToExpression(lhs=not_comparable, rhs=comparable).resolve(state)

    # THEN the mirrored error message is produced
    assert str(exc_info.value) == "'<=' not supported between instances of 'str' and 'Comparable'"
226
+
227
+
228
def test_less_than_or_equal_to_non_comparable():
    # GIVEN a comparable object and one with no comparison support at all
    comparable = Comparable(4, 5)
    non_comparable = NonComparable(2, 10)

    state = TestState()

    # WHEN the non-comparable is on the right-hand side of <=
    with pytest.raises(TypeError) as exc_info:
        LessThanOrEqualToExpression(lhs=comparable, rhs=non_comparable).resolve(state)

    # THEN the standard unsupported-comparison error surfaces
    assert str(exc_info.value) == "'<=' not supported between instances of 'Comparable' and 'NonComparable'"

    # WHEN the non-comparable is on the left-hand side of <=
    with pytest.raises(TypeError) as exc_info:
        LessThanOrEqualToExpression(lhs=non_comparable, rhs=comparable).resolve(state)

    # THEN the mirrored error message is produced
    assert str(exc_info.value) == "'<=' not supported between instances of 'NonComparable' and 'Comparable'"
248
+
249
+
250
def test_less_than():
    # GIVEN comparables whose ordering metric is value1 * value2
    twenty_a = Comparable(4, 5)  # metric: 20
    twenty_b = Comparable(2, 10)  # metric: 20
    eighteen = Comparable(3, 6)  # metric: 18
    twenty_five = Comparable(5, 5)  # metric: 25

    state = TestState()

    # WHEN comparing object against object
    assert LessThanExpression(lhs=twenty_a, rhs=twenty_b).resolve(state) is False  # 20 is not < 20
    assert LessThanExpression(lhs=twenty_a, rhs=eighteen).resolve(state) is False  # 20 is not < 18
    assert LessThanExpression(lhs=eighteen, rhs=twenty_five).resolve(state) is True  # 18 < 25

    # WHEN comparing object against raw numbers
    assert LessThanExpression(lhs=twenty_a, rhs=21).resolve(state) is True  # 20 < 21
    assert LessThanExpression(lhs=eighteen, rhs=17).resolve(state) is False  # 18 > 17
267
+
268
+
269
def test_less_than_invalid():
    # GIVEN a comparable object and a plain string
    comparable = Comparable(4, 5)
    not_comparable = "invalid"

    state = TestState()

    # WHEN the string is on the right-hand side of <
    with pytest.raises(TypeError) as exc_info:
        LessThanExpression(lhs=comparable, rhs=not_comparable).resolve(state)

    # THEN Python's standard unsupported-comparison error surfaces
    assert str(exc_info.value) == "'<' not supported between instances of 'Comparable' and 'str'"

    # WHEN the string is on the left-hand side of <
    with pytest.raises(TypeError) as exc_info:
        LessThanExpression(lhs=not_comparable, rhs=comparable).resolve(state)

    # THEN the mirrored error message is produced
    assert str(exc_info.value) == "'<' not supported between instances of 'str' and 'Comparable'"
289
+
290
+
291
def test_less_than_non_comparable():
    # GIVEN a comparable object and one with no comparison support at all
    comparable = Comparable(4, 5)
    non_comparable = NonComparable(2, 10)

    state = TestState()

    # WHEN the non-comparable is on the right-hand side of <
    with pytest.raises(TypeError) as exc_info:
        LessThanExpression(lhs=comparable, rhs=non_comparable).resolve(state)

    # THEN the standard unsupported-comparison error surfaces
    assert str(exc_info.value) == "'<' not supported between instances of 'Comparable' and 'NonComparable'"

    # WHEN the non-comparable is on the left-hand side of <
    with pytest.raises(TypeError) as exc_info:
        LessThanExpression(lhs=non_comparable, rhs=comparable).resolve(state)

    # THEN the mirrored error message is produced
    assert str(exc_info.value) == "'<' not supported between instances of 'NonComparable' and 'Comparable'"
@@ -0,0 +1,31 @@
1
+ import pytest
2
+
3
+ from vellum.workflows.descriptors.exceptions import InvalidExpressionException
4
+ from vellum.workflows.references.constant import ConstantValueReference
5
+ from vellum.workflows.state.base import BaseState
6
+
7
+
8
def test_parse_json_invalid_json():
    # GIVEN a constant holding malformed JSON (the value is unquoted)
    state = BaseState()
    expression = ConstantValueReference('{"key": value}').parse_json()

    # WHEN we attempt to resolve the expression
    with pytest.raises(InvalidExpressionException) as exc_info:
        expression.resolve(state)

    # THEN the parse failure is surfaced as an InvalidExpressionException
    assert "Failed to parse JSON" in str(exc_info.value)
19
+
20
+
21
def test_parse_json_invalid_type():
    # GIVEN a constant holding a non-string value
    state = BaseState()
    expression = ConstantValueReference(123).parse_json()

    # WHEN we attempt to resolve the expression
    with pytest.raises(InvalidExpressionException) as exc_info:
        expression.resolve(state)

    # THEN the type mismatch is reported with the offending value and type
    assert "Expected a string, but got 123 of type <class 'int'>" == str(exc_info.value)
@@ -318,9 +318,12 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
318
318
  original_base = get_original_base(self.__class__)
319
319
 
320
320
  args = get_args(original_base)
321
- state_type = args[0]
322
321
 
323
- if isinstance(state_type, TypeVar):
322
+ if args and len(args) > 0:
323
+ state_type = args[0]
324
+ if isinstance(state_type, TypeVar):
325
+ state_type = BaseState
326
+ else:
324
327
  state_type = BaseState
325
328
 
326
329
  self.state = state_type()
@@ -1,7 +1,10 @@
1
+ import pytest
2
+
1
3
  from vellum import ExecuteApiResponse, VellumSecret as ClientVellumSecret
4
+ from vellum.client.core.api_error import ApiError
2
5
  from vellum.workflows.constants import APIRequestMethod, AuthorizationType
6
+ from vellum.workflows.exceptions import NodeException
3
7
  from vellum.workflows.nodes import APINode
4
- from vellum.workflows.state import BaseState
5
8
  from vellum.workflows.types.core import VellumSecret
6
9
 
7
10
 
@@ -25,10 +28,39 @@ def test_run_workflow__secrets(vellum_client):
25
28
  }
26
29
  bearer_token_value = VellumSecret(name="secret")
27
30
 
28
- node = SimpleBaseAPINode(state=BaseState())
31
+ node = SimpleBaseAPINode()
29
32
  terminal = node.run()
30
33
 
31
34
  assert vellum_client.execute_api.call_count == 1
32
35
  bearer_token = vellum_client.execute_api.call_args.kwargs["bearer_token"]
33
36
  assert bearer_token == ClientVellumSecret(name="secret")
34
37
  assert terminal.headers == {"X-Response-Header": "bar"}
38
+
39
+
40
+ def test_api_node_raises_error_when_api_call_fails(vellum_client):
41
+ # GIVEN an API call that fails
42
+ vellum_client.execute_api.side_effect = ApiError(status_code=400, body="API Error")
43
+
44
+ class SimpleAPINode(APINode):
45
+ method = APIRequestMethod.GET
46
+ authorization_type = AuthorizationType.BEARER_TOKEN
47
+ url = "https://api.vellum.ai"
48
+ body = {
49
+ "key": "value",
50
+ }
51
+ headers = {
52
+ "X-Test-Header": "foo",
53
+ }
54
+ bearer_token_value = VellumSecret(name="api_key")
55
+
56
+ node = SimpleAPINode()
57
+
58
+ # WHEN we run the node
59
+ with pytest.raises(NodeException) as excinfo:
60
+ node.run()
61
+
62
+ # THEN an exception should be raised
63
+ assert "Failed to prepare HTTP request" in str(excinfo.value)
64
+
65
+ # AND the API call should have been made
66
+ assert vellum_client.execute_api.call_count == 1
@@ -89,7 +89,7 @@ class BaseAPINode(BaseNode, Generic[StateType]):
89
89
  url=url, method=method.value, body=data, headers=headers, bearer_token=client_vellum_secret
90
90
  )
91
91
  except ApiError as e:
92
- NodeException(f"Failed to prepare HTTP request: {e}", code=WorkflowErrorCode.NODE_EXECUTION)
92
+ raise NodeException(f"Failed to prepare HTTP request: {e}", code=WorkflowErrorCode.NODE_EXECUTION)
93
93
 
94
94
  return self.Outputs(
95
95
  json=vellum_response.json_,
@@ -19,6 +19,7 @@ from vellum import (
19
19
  VellumError,
20
20
  VellumValue,
21
21
  )
22
+ from vellum.client.core.api_error import ApiError
22
23
  from vellum.client.types.code_executor_secret_input import CodeExecutorSecretInput
23
24
  from vellum.core import RequestOptions
24
25
  from vellum.workflows.errors.types import WorkflowErrorCode
@@ -103,14 +104,23 @@ class CodeExecutionNode(BaseNode[StateType], Generic[StateType, _OutputType], me
103
104
  input_values = self._compile_code_inputs()
104
105
  expected_output_type = primitive_type_to_vellum_variable_type(output_type)
105
106
 
106
- code_execution_result = self._context.vellum_client.execute_code(
107
- input_values=input_values,
108
- code=code,
109
- runtime=self.runtime,
110
- output_type=expected_output_type,
111
- packages=self.packages or [],
112
- request_options=self.request_options,
113
- )
107
+ try:
108
+ code_execution_result = self._context.vellum_client.execute_code(
109
+ input_values=input_values,
110
+ code=code,
111
+ runtime=self.runtime,
112
+ output_type=expected_output_type,
113
+ packages=self.packages or [],
114
+ request_options=self.request_options,
115
+ )
116
+ except ApiError as e:
117
+ if e.status_code == 400 and isinstance(e.body, dict) and "message" in e.body:
118
+ raise NodeException(
119
+ message=e.body["message"],
120
+ code=WorkflowErrorCode.INVALID_INPUTS,
121
+ )
122
+
123
+ raise
114
124
 
115
125
  if code_execution_result.output.type != expected_output_type:
116
126
  actual_type = code_execution_result.output.type
@@ -5,6 +5,7 @@ from typing import Any, List, Union
5
5
  from pydantic import BaseModel
6
6
 
7
7
  from vellum import CodeExecutorResponse, NumberVellumValue, StringInput, StringVellumValue
8
+ from vellum.client.errors.bad_request_error import BadRequestError
8
9
  from vellum.client.types.chat_message import ChatMessage
9
10
  from vellum.client.types.code_execution_package import CodeExecutionPackage
10
11
  from vellum.client.types.code_executor_secret_input import CodeExecutorSecretInput
@@ -17,6 +18,7 @@ from vellum.workflows.inputs.base import BaseInputs
17
18
  from vellum.workflows.nodes.displayable.code_execution_node import CodeExecutionNode
18
19
  from vellum.workflows.references.vellum_secret import VellumSecretReference
19
20
  from vellum.workflows.state.base import BaseState, StateMeta
21
+ from vellum.workflows.types.core import Json
20
22
 
21
23
 
22
24
  def test_run_node__happy_path(vellum_client):
@@ -690,3 +692,54 @@ def main():
690
692
  "result": [ChatMessage(role="USER", content=StringChatMessageContent(value="Hello, world!"))],
691
693
  "log": "",
692
694
  }
695
+
696
+
697
+ def test_run_node__execute_code_api_fails__node_exception(vellum_client):
698
+ # GIVEN a node that will throw a JSON.parse error
699
+ class ExampleCodeExecutionNode(CodeExecutionNode[BaseState, Json]):
700
+ code = """\
701
+ async function main(inputs: {
702
+ data: string,
703
+ }): Promise<string> {
704
+ return JSON.parse(inputs.data)
705
+ }
706
+ """
707
+ code_inputs = {
708
+ "data": "not a valid json string",
709
+ }
710
+ runtime = "TYPESCRIPT_5_3_3"
711
+
712
+ # AND the execute_code API will fail
713
+ message = """\
714
+ Code execution error (exit code 1): undefined:1
715
+ not a valid json string
716
+ ^
717
+
718
+ SyntaxError: Unexpected token 'o', \"not a valid\"... is not valid JSON
719
+ at JSON.parse (<anonymous>)
720
+ at Object.eval (eval at execute (/workdir/runner.js:16:18), <anonymous>:40:40)
721
+ at step (eval at execute (/workdir/runner.js:16:18), <anonymous>:32:23)
722
+ at Object.eval [as next] (eval at execute (/workdir/runner.js:16:18), <anonymous>:13:53)
723
+ at eval (eval at execute (/workdir/runner.js:16:18), <anonymous>:7:71)
724
+ at new Promise (<anonymous>)
725
+ at __awaiter (eval at execute (/workdir/runner.js:16:18), <anonymous>:3:12)
726
+ at Object.main (eval at execute (/workdir/runner.js:16:18), <anonymous>:38:12)
727
+ at execute (/workdir/runner.js:17:33)
728
+ at Interface.<anonymous> (/workdir/runner.js:58:5)
729
+
730
+ Node.js v21.7.3
731
+ """
732
+ vellum_client.execute_code.side_effect = BadRequestError(
733
+ body={
734
+ "message": message,
735
+ "log": "",
736
+ }
737
+ )
738
+
739
+ # WHEN we run the node
740
+ node = ExampleCodeExecutionNode()
741
+ with pytest.raises(NodeException) as exc_info:
742
+ node.run()
743
+
744
+ # AND the error should contain the execution error details
745
+ assert exc_info.value.message == message
@@ -1,5 +1,6 @@
1
1
  from collections import defaultdict
2
2
  from copy import deepcopy
3
+ from dataclasses import dataclass
3
4
  import logging
4
5
  from queue import Empty, Queue
5
6
  from threading import Event as ThreadingEvent, Thread
@@ -63,6 +64,12 @@ ExternalInputsArg = Dict[ExternalInputReference, Any]
63
64
  BackgroundThreadItem = Union[BaseState, WorkflowEvent, None]
64
65
 
65
66
 
67
+ @dataclass
68
+ class ActiveNode(Generic[StateType]):
69
+ node: BaseNode[StateType]
70
+ was_outputs_streamed: bool = False
71
+
72
+
66
73
  class WorkflowRunner(Generic[StateType]):
67
74
  _entrypoints: Iterable[Type[BaseNode]]
68
75
 
@@ -136,7 +143,7 @@ class WorkflowRunner(Generic[StateType]):
136
143
  self._dependencies: Dict[Type[BaseNode], Set[Type[BaseNode]]] = defaultdict(set)
137
144
  self._state_forks: Set[StateType] = {self._initial_state}
138
145
 
139
- self._active_nodes_by_execution_id: Dict[UUID, BaseNode[StateType]] = {}
146
+ self._active_nodes_by_execution_id: Dict[UUID, ActiveNode[StateType]] = {}
140
147
  self._cancel_signal = cancel_signal
141
148
  self._execution_context = init_execution_context or get_execution_context()
142
149
  self._parent_context = self._execution_context.parent_context
@@ -404,7 +411,7 @@ class WorkflowRunner(Generic[StateType]):
404
411
  current_parent = get_parent_context()
405
412
  node = node_class(state=state, context=self.workflow.context)
406
413
  state.meta.node_execution_cache.initiate_node_execution(node_class, node_span_id)
407
- self._active_nodes_by_execution_id[node_span_id] = node
414
+ self._active_nodes_by_execution_id[node_span_id] = ActiveNode(node=node)
408
415
 
409
416
  worker_thread = Thread(
410
417
  target=self._context_run_work_item,
@@ -413,10 +420,11 @@ class WorkflowRunner(Generic[StateType]):
413
420
  worker_thread.start()
414
421
 
415
422
  def _handle_work_item_event(self, event: WorkflowEvent) -> Optional[WorkflowError]:
416
- node = self._active_nodes_by_execution_id.get(event.span_id)
417
- if not node:
423
+ active_node = self._active_nodes_by_execution_id.get(event.span_id)
424
+ if not active_node:
418
425
  return None
419
426
 
427
+ node = active_node.node
420
428
  if event.name == "node.execution.rejected":
421
429
  self._active_nodes_by_execution_id.pop(event.span_id)
422
430
  return event.error
@@ -431,6 +439,7 @@ class WorkflowRunner(Generic[StateType]):
431
439
  if node_output_descriptor.name != event.output.name:
432
440
  continue
433
441
 
442
+ active_node.was_outputs_streamed = True
434
443
  self._workflow_event_outer_queue.put(
435
444
  self._stream_workflow_event(
436
445
  BaseOutput(
@@ -447,6 +456,26 @@ class WorkflowRunner(Generic[StateType]):
447
456
 
448
457
  if event.name == "node.execution.fulfilled":
449
458
  self._active_nodes_by_execution_id.pop(event.span_id)
459
+ if not active_node.was_outputs_streamed:
460
+ for event_node_output_descriptor, node_output_value in event.outputs:
461
+ for workflow_output_descriptor in self.workflow.Outputs:
462
+ node_output_descriptor = workflow_output_descriptor.instance
463
+ if not isinstance(node_output_descriptor, OutputReference):
464
+ continue
465
+ if node_output_descriptor.outputs_class != event.node_definition.Outputs:
466
+ continue
467
+ if node_output_descriptor.name != event_node_output_descriptor.name:
468
+ continue
469
+
470
+ self._workflow_event_outer_queue.put(
471
+ self._stream_workflow_event(
472
+ BaseOutput(
473
+ name=workflow_output_descriptor.name,
474
+ value=node_output_value,
475
+ )
476
+ )
477
+ )
478
+
450
479
  self._handle_invoked_ports(node.state, event.invoked_ports)
451
480
 
452
481
  return None
@@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, Type
8
8
 
9
9
  from pydantic import BaseModel
10
10
 
11
+ from vellum.workflows.constants import undefined
11
12
  from vellum.workflows.inputs.base import BaseInputs
12
13
  from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
13
14
  from vellum.workflows.ports.port import Port
@@ -22,7 +23,7 @@ class DefaultStateEncoder(JSONEncoder):
22
23
  return dict(obj)
23
24
 
24
25
  if isinstance(obj, (BaseInputs, BaseOutputs)):
25
- return {descriptor.name: value for descriptor, value in obj}
26
+ return {descriptor.name: value for descriptor, value in obj if value is not undefined}
26
27
 
27
28
  if isinstance(obj, (BaseOutput, Port)):
28
29
  return obj.serialize()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vellum-ai
3
- Version: 0.14.7
3
+ Version: 0.14.9
4
4
  Summary:
5
5
  License: MIT
6
6
  Requires-Python: >=3.9,<4.0