vellum-ai 0.13.28__py3-none-any.whl → 0.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/workflows/constants.py +8 -3
  3. vellum/workflows/descriptors/tests/test_utils.py +21 -0
  4. vellum/workflows/descriptors/utils.py +3 -3
  5. vellum/workflows/errors/types.py +4 -1
  6. vellum/workflows/expressions/coalesce_expression.py +2 -2
  7. vellum/workflows/expressions/contains.py +4 -3
  8. vellum/workflows/expressions/does_not_contain.py +2 -1
  9. vellum/workflows/expressions/is_nil.py +2 -2
  10. vellum/workflows/expressions/is_not_nil.py +2 -2
  11. vellum/workflows/expressions/is_not_undefined.py +2 -2
  12. vellum/workflows/expressions/is_undefined.py +2 -2
  13. vellum/workflows/nodes/bases/base.py +19 -3
  14. vellum/workflows/nodes/bases/tests/test_base_node.py +84 -0
  15. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +3 -3
  16. vellum/workflows/nodes/core/map_node/node.py +5 -0
  17. vellum/workflows/nodes/core/map_node/tests/test_node.py +22 -0
  18. vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +68 -2
  19. vellum/workflows/nodes/displayable/code_execution_node/utils.py +30 -7
  20. vellum/workflows/outputs/base.py +21 -19
  21. vellum/workflows/references/external_input.py +2 -2
  22. vellum/workflows/references/lazy.py +2 -2
  23. vellum/workflows/references/output.py +7 -7
  24. vellum/workflows/runner/runner.py +20 -15
  25. vellum/workflows/state/base.py +2 -2
  26. vellum/workflows/state/tests/test_state.py +7 -11
  27. vellum/workflows/workflows/base.py +20 -0
  28. vellum/workflows/workflows/tests/__init__.py +0 -0
  29. vellum/workflows/workflows/tests/test_base_workflow.py +80 -0
  30. {vellum_ai-0.13.28.dist-info → vellum_ai-0.14.0.dist-info}/METADATA +1 -1
  31. {vellum_ai-0.13.28.dist-info → vellum_ai-0.14.0.dist-info}/RECORD +35 -33
  32. vellum_ee/workflows/display/nodes/base_node_display.py +2 -2
  33. {vellum_ai-0.13.28.dist-info → vellum_ai-0.14.0.dist-info}/LICENSE +0 -0
  34. {vellum_ai-0.13.28.dist-info → vellum_ai-0.14.0.dist-info}/WHEEL +0 -0
  35. {vellum_ai-0.13.28.dist-info → vellum_ai-0.14.0.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
18
18
  headers: typing.Dict[str, str] = {
19
19
  "X-Fern-Language": "Python",
20
20
  "X-Fern-SDK-Name": "vellum-ai",
21
- "X-Fern-SDK-Version": "0.13.28",
21
+ "X-Fern-SDK-Version": "0.14.0",
22
22
  }
23
23
  headers["X_API_KEY"] = self.api_key
24
24
  return headers
@@ -4,11 +4,11 @@ from typing import Any, cast
4
4
 
5
5
  class _UndefMeta(type):
6
6
  def __repr__(cls) -> str:
7
- return "UNDEF"
7
+ return "undefined"
8
8
 
9
9
  def __getattribute__(cls, name: str) -> Any:
10
10
  if name == "__class__":
11
- # ensures that UNDEF.__class__ == UNDEF
11
+ # ensures that undefined.__class__ == undefined
12
12
  return cls
13
13
 
14
14
  return super().__getattribute__(name)
@@ -17,7 +17,12 @@ class _UndefMeta(type):
17
17
  return False
18
18
 
19
19
 
20
- class UNDEF(metaclass=_UndefMeta):
20
+ class undefined(metaclass=_UndefMeta):
21
+ """
22
+ A singleton class that represents an `undefined` value, mirroring the behavior of the `undefined`
23
+ value in TypeScript.
24
+ """
25
+
21
26
  pass
22
27
 
23
28
 
@@ -1,6 +1,7 @@
1
1
  import pytest
2
2
 
3
3
  from vellum.workflows.descriptors.utils import resolve_value
4
+ from vellum.workflows.errors.types import WorkflowError, WorkflowErrorCode
4
5
  from vellum.workflows.nodes.bases.base import BaseNode
5
6
  from vellum.workflows.references.constant import ConstantValueReference
6
7
  from vellum.workflows.state.base import BaseState
@@ -77,6 +78,24 @@ class DummyNode(BaseNode[FixtureState]):
77
78
  (FixtureState.zeta["foo"], "bar"),
78
79
  (ConstantValueReference(1), 1),
79
80
  (FixtureState.theta[0], "baz"),
81
+ (
82
+ ConstantValueReference(
83
+ WorkflowError(
84
+ message="This is a test",
85
+ code=WorkflowErrorCode.USER_DEFINED_ERROR,
86
+ )
87
+ ).contains("test"),
88
+ True,
89
+ ),
90
+ (
91
+ ConstantValueReference(
92
+ WorkflowError(
93
+ message="This is a test",
94
+ code=WorkflowErrorCode.USER_DEFINED_ERROR,
95
+ )
96
+ ).does_not_contain("test"),
97
+ False,
98
+ ),
80
99
  ],
81
100
  ids=[
82
101
  "or",
@@ -122,6 +141,8 @@ class DummyNode(BaseNode[FixtureState]):
122
141
  "accessor",
123
142
  "constants",
124
143
  "list_index",
144
+ "error_contains",
145
+ "error_does_not_contain",
125
146
  ],
126
147
  )
127
148
  def test_resolve_value__happy_path(descriptor, expected_value):
@@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Sequence, Set, TypeVar, Union, cast, ove
5
5
 
6
6
  from pydantic import BaseModel
7
7
 
8
- from vellum.workflows.constants import UNDEF
8
+ from vellum.workflows.constants import undefined
9
9
  from vellum.workflows.descriptors.base import BaseDescriptor
10
10
  from vellum.workflows.state.base import BaseState
11
11
 
@@ -93,10 +93,10 @@ def resolve_value(
93
93
 
94
94
  def is_unresolved(value: Any) -> bool:
95
95
  """
96
- Recursively checks if a value has an unresolved value, represented by UNDEF.
96
+ Recursively checks if a value has an unresolved value, represented by undefined.
97
97
  """
98
98
 
99
- if value is UNDEF:
99
+ if value is undefined:
100
100
  return True
101
101
 
102
102
  if dataclasses.is_dataclass(value):
@@ -1,6 +1,6 @@
1
1
  from dataclasses import dataclass
2
2
  from enum import Enum
3
- from typing import Dict
3
+ from typing import Any, Dict
4
4
 
5
5
  from vellum.client.types.vellum_error import VellumError
6
6
  from vellum.client.types.vellum_error_code_enum import VellumErrorCodeEnum
@@ -26,6 +26,9 @@ class WorkflowError:
26
26
  message: str
27
27
  code: WorkflowErrorCode
28
28
 
29
+ def __contains__(self, item: Any) -> bool:
30
+ return item in self.message
31
+
29
32
 
30
33
  _VELLUM_ERROR_CODE_TO_WORKFLOW_ERROR_CODE: Dict[VellumErrorCodeEnum, WorkflowErrorCode] = {
31
34
  "INVALID_REQUEST": WorkflowErrorCode.INVALID_INPUTS,
@@ -1,6 +1,6 @@
1
1
  from typing import TypeVar, Union
2
2
 
3
- from vellum.workflows.constants import UNDEF
3
+ from vellum.workflows.constants import undefined
4
4
  from vellum.workflows.descriptors.base import BaseDescriptor
5
5
  from vellum.workflows.descriptors.utils import resolve_value
6
6
  from vellum.workflows.state.base import BaseState
@@ -27,7 +27,7 @@ class CoalesceExpression(BaseDescriptor[Union[LHS, RHS]]):
27
27
 
28
28
  def resolve(self, state: "BaseState") -> Union[LHS, RHS]:
29
29
  lhs = resolve_value(self._lhs, state)
30
- if lhs is not UNDEF and lhs is not None:
30
+ if lhs is not undefined and lhs is not None:
31
31
  return lhs
32
32
 
33
33
  return resolve_value(self._rhs, state)
@@ -1,9 +1,10 @@
1
1
  from typing import Generic, TypeVar, Union
2
2
 
3
- from vellum.workflows.constants import UNDEF
3
+ from vellum.workflows.constants import undefined
4
4
  from vellum.workflows.descriptors.base import BaseDescriptor
5
5
  from vellum.workflows.descriptors.exceptions import InvalidExpressionException
6
6
  from vellum.workflows.descriptors.utils import resolve_value
7
+ from vellum.workflows.errors.types import WorkflowError
7
8
  from vellum.workflows.state.base import BaseState
8
9
 
9
10
  LHS = TypeVar("LHS")
@@ -26,9 +27,9 @@ class ContainsExpression(BaseDescriptor[bool], Generic[LHS, RHS]):
26
27
  # https://app.shortcut.com/vellum/story/4658
27
28
  lhs = resolve_value(self._lhs, state)
28
29
  # assumes that lack of is also false
29
- if lhs is UNDEF:
30
+ if lhs is undefined:
30
31
  return False
31
- if not isinstance(lhs, (list, tuple, set, dict, str)):
32
+ if not isinstance(lhs, (list, tuple, set, dict, str, WorkflowError)):
32
33
  raise InvalidExpressionException(
33
34
  f"Expected a LHS that supported `contains`, got `{lhs.__class__.__name__}`"
34
35
  )
@@ -3,6 +3,7 @@ from typing import Generic, TypeVar, Union
3
3
  from vellum.workflows.descriptors.base import BaseDescriptor
4
4
  from vellum.workflows.descriptors.exceptions import InvalidExpressionException
5
5
  from vellum.workflows.descriptors.utils import resolve_value
6
+ from vellum.workflows.errors.types import WorkflowError
6
7
  from vellum.workflows.state.base import BaseState
7
8
 
8
9
  LHS = TypeVar("LHS")
@@ -24,7 +25,7 @@ class DoesNotContainExpression(BaseDescriptor[bool], Generic[LHS, RHS]):
24
25
  # Support any type that implements the not in operator
25
26
  # https://app.shortcut.com/vellum/story/4658
26
27
  lhs = resolve_value(self._lhs, state)
27
- if not isinstance(lhs, (list, tuple, set, dict, str)):
28
+ if not isinstance(lhs, (list, tuple, set, dict, str, WorkflowError)):
28
29
  raise InvalidExpressionException(
29
30
  f"Expected a LHS that supported `contains`, got `{lhs.__class__.__name__}`"
30
31
  )
@@ -1,6 +1,6 @@
1
1
  from typing import Generic, TypeVar, Union
2
2
 
3
- from vellum.workflows.constants import UNDEF
3
+ from vellum.workflows.constants import undefined
4
4
  from vellum.workflows.descriptors.base import BaseDescriptor
5
5
  from vellum.workflows.descriptors.utils import resolve_value
6
6
  from vellum.workflows.state.base import BaseState
@@ -19,4 +19,4 @@ class IsNilExpression(BaseDescriptor[bool], Generic[_T]):
19
19
 
20
20
  def resolve(self, state: "BaseState") -> bool:
21
21
  expression = resolve_value(self._expression, state)
22
- return expression is None or expression is UNDEF
22
+ return expression is None or expression is undefined
@@ -1,6 +1,6 @@
1
1
  from typing import Generic, TypeVar, Union
2
2
 
3
- from vellum.workflows.constants import UNDEF
3
+ from vellum.workflows.constants import undefined
4
4
  from vellum.workflows.descriptors.base import BaseDescriptor
5
5
  from vellum.workflows.descriptors.utils import resolve_value
6
6
  from vellum.workflows.state.base import BaseState
@@ -19,4 +19,4 @@ class IsNotNilExpression(BaseDescriptor[bool], Generic[_T]):
19
19
 
20
20
  def resolve(self, state: "BaseState") -> bool:
21
21
  expression = resolve_value(self._expression, state)
22
- return expression is not None and expression is not UNDEF
22
+ return expression is not None and expression is not undefined
@@ -1,6 +1,6 @@
1
1
  from typing import Generic, TypeVar, Union
2
2
 
3
- from vellum.workflows.constants import UNDEF
3
+ from vellum.workflows.constants import undefined
4
4
  from vellum.workflows.descriptors.base import BaseDescriptor
5
5
  from vellum.workflows.descriptors.utils import resolve_value
6
6
  from vellum.workflows.state.base import BaseState
@@ -19,4 +19,4 @@ class IsNotUndefinedExpression(BaseDescriptor[bool], Generic[_T]):
19
19
 
20
20
  def resolve(self, state: "BaseState") -> bool:
21
21
  expression = resolve_value(self._expression, state)
22
- return expression is not UNDEF
22
+ return expression is not undefined
@@ -1,6 +1,6 @@
1
1
  from typing import Generic, TypeVar, Union
2
2
 
3
- from vellum.workflows.constants import UNDEF
3
+ from vellum.workflows.constants import undefined
4
4
  from vellum.workflows.descriptors.base import BaseDescriptor
5
5
  from vellum.workflows.descriptors.utils import resolve_value
6
6
  from vellum.workflows.state.base import BaseState
@@ -19,4 +19,4 @@ class IsUndefinedExpression(BaseDescriptor[bool], Generic[_T]):
19
19
 
20
20
  def resolve(self, state: "BaseState") -> bool:
21
21
  expression = resolve_value(self._expression, state)
22
- return expression is UNDEF
22
+ return expression is undefined
@@ -5,7 +5,7 @@ from types import MappingProxyType
5
5
  from uuid import UUID
6
6
  from typing import Any, Dict, Generic, Iterator, Optional, Set, Tuple, Type, TypeVar, Union, cast, get_args
7
7
 
8
- from vellum.workflows.constants import UNDEF
8
+ from vellum.workflows.constants import undefined
9
9
  from vellum.workflows.descriptors.base import BaseDescriptor
10
10
  from vellum.workflows.descriptors.utils import is_unresolved, resolve_value
11
11
  from vellum.workflows.errors.types import WorkflowErrorCode
@@ -43,7 +43,23 @@ class BaseNodeMeta(type):
43
43
  # TODO: Inherit the inner Output classes from every base class.
44
44
  # https://app.shortcut.com/vellum/story/4007/support-auto-inheriting-parent-node-outputs
45
45
 
46
- if "Outputs" not in dct:
46
+ if "Outputs" in dct:
47
+ outputs_class = dct["Outputs"]
48
+ if not any(issubclass(base, BaseOutputs) for base in outputs_class.__bases__):
49
+ parent_outputs_class = next(
50
+ (base.Outputs for base in bases if hasattr(base, "Outputs")),
51
+ BaseOutputs, # Default to BaseOutputs only if no parent has Outputs
52
+ )
53
+
54
+ # Filter out object from bases while preserving other inheritance
55
+ filtered_bases = tuple(base for base in outputs_class.__bases__ if base is not object)
56
+
57
+ dct["Outputs"] = type(
58
+ f"{name}.Outputs",
59
+ (parent_outputs_class,) + filtered_bases,
60
+ {**outputs_class.__dict__, "__module__": dct["__module__"]},
61
+ )
62
+ else:
47
63
  for base in reversed(bases):
48
64
  if hasattr(base, "Outputs"):
49
65
  dct["Outputs"] = type(
@@ -165,7 +181,7 @@ class BaseNodeMeta(type):
165
181
  if attr_name in yielded_attr_names:
166
182
  continue
167
183
 
168
- attr_value = getattr(resolved_cls, attr_name, UNDEF)
184
+ attr_value = getattr(resolved_cls, attr_name, undefined)
169
185
  if not isinstance(attr_value, NodeReference):
170
186
  continue
171
187
 
@@ -5,6 +5,7 @@ from vellum.client.types.string_vellum_value_request import StringVellumValueReq
5
5
  from vellum.core.pydantic_utilities import UniversalBaseModel
6
6
  from vellum.workflows.inputs.base import BaseInputs
7
7
  from vellum.workflows.nodes.bases.base import BaseNode
8
+ from vellum.workflows.outputs.base import BaseOutputs
8
9
  from vellum.workflows.state.base import BaseState, StateMeta
9
10
 
10
11
 
@@ -148,3 +149,86 @@ def test_base_node__node_resolution__descriptor_in_fern_pydantic():
148
149
  node = SomeNode(state=State(foo="bar"))
149
150
 
150
151
  assert node.model.value == "bar"
152
+
153
+
154
+ def test_base_node__inherit_base_outputs():
155
+ class MyNode(BaseNode):
156
+ class Outputs:
157
+ foo: str
158
+
159
+ def run(self):
160
+ return self.Outputs(foo="bar") # type: ignore
161
+
162
+ # TEST that the Outputs class is a subclass of BaseOutputs
163
+ assert issubclass(MyNode.Outputs, BaseOutputs)
164
+
165
+ # TEST that the Outputs class does not inherit from object
166
+ assert object not in MyNode.Outputs.__bases__
167
+
168
+ # TEST that the Outputs class has the correct attributes
169
+ assert hasattr(MyNode.Outputs, "foo")
170
+
171
+ # WHEN the node is run
172
+ node = MyNode()
173
+ outputs = node.run()
174
+
175
+ # THEN the outputs should be correct
176
+ assert outputs.foo == "bar"
177
+
178
+
179
+ def test_child_node__inherits_base_outputs_when_no_parent_outputs():
180
+ class ParentNode(BaseNode): # No Outputs class here
181
+ pass
182
+
183
+ class ChildNode(ParentNode):
184
+ class Outputs:
185
+ foo: str
186
+
187
+ def run(self):
188
+ return self.Outputs(foo="bar") # type: ignore
189
+
190
+ # TEST that ChildNode.Outputs is a subclass of BaseOutputs (since ParentNode has no Outputs)
191
+ assert issubclass(ChildNode.Outputs, BaseOutputs)
192
+
193
+ # TEST that ChildNode.Outputs has the correct attributes
194
+ assert hasattr(ChildNode.Outputs, "foo")
195
+
196
+ # WHEN the node is run
197
+ node = ChildNode()
198
+ outputs = node.run()
199
+
200
+ # THEN the outputs should be correct
201
+ assert outputs.foo == "bar"
202
+
203
+
204
+ def test_outputs_preserves_non_object_bases():
205
+ class ParentNode(BaseNode):
206
+ class Outputs:
207
+ foo: str
208
+
209
+ class Foo:
210
+ bar: str
211
+
212
+ class ChildNode(ParentNode):
213
+ class Outputs(ParentNode.Outputs, Foo):
214
+ pass
215
+
216
+ def run(self):
217
+ return self.Outputs(foo="bar", bar="baz") # type: ignore
218
+
219
+ # TEST that Outputs is a subclass of Foo and ParentNode.Outputs
220
+ assert Foo in ChildNode.Outputs.__bases__, "Foo should be preserved in bases"
221
+ assert ParentNode.Outputs in ChildNode.Outputs.__bases__, "ParentNode.Outputs should be preserved in bases"
222
+ assert object not in ChildNode.Outputs.__bases__, "object should not be in bases"
223
+
224
+ # TEST that Outputs has the correct attributes
225
+ assert hasattr(ChildNode.Outputs, "foo")
226
+ assert hasattr(ChildNode.Outputs, "bar")
227
+
228
+ # WHEN Outputs is instantiated
229
+ node = ChildNode()
230
+ outputs = node.run()
231
+
232
+ # THEN the output values should be correct
233
+ assert outputs.foo == "bar"
234
+ assert outputs.bar == "baz"
@@ -1,6 +1,6 @@
1
1
  from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Iterator, Optional, Set, Tuple, Type, TypeVar, Union
2
2
 
3
- from vellum.workflows.constants import UNDEF
3
+ from vellum.workflows.constants import undefined
4
4
  from vellum.workflows.context import execution_context, get_parent_context
5
5
  from vellum.workflows.errors.types import WorkflowErrorCode
6
6
  from vellum.workflows.exceptions import NodeException
@@ -67,7 +67,7 @@ class InlineSubworkflowNode(
67
67
  """
68
68
 
69
69
  subworkflow: Type["BaseWorkflow[InputsType, InnerStateType]"]
70
- subworkflow_inputs: ClassVar[Union[EntityInputsInterface, BaseInputs, Type[UNDEF]]] = UNDEF
70
+ subworkflow_inputs: ClassVar[Union[EntityInputsInterface, BaseInputs, Type[undefined]]] = undefined
71
71
 
72
72
  def run(self) -> Iterator[BaseOutput]:
73
73
  with execution_context(parent_context=get_parent_context() or self._context.parent_context):
@@ -112,7 +112,7 @@ class InlineSubworkflowNode(
112
112
 
113
113
  def _compile_subworkflow_inputs(self) -> InputsType:
114
114
  inputs_class = self.subworkflow.get_inputs_class()
115
- if self.subworkflow_inputs is UNDEF:
115
+ if self.subworkflow_inputs is undefined:
116
116
  inputs_dict = {}
117
117
  for descriptor in inputs_class:
118
118
  if hasattr(self, descriptor.name):
@@ -66,6 +66,11 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
66
66
  for output_descripter in self.subworkflow.Outputs:
67
67
  mapped_items[output_descripter.name] = [None] * len(self.items)
68
68
 
69
+ if not self.items:
70
+ for output_name, output_list in mapped_items.items():
71
+ yield BaseOutput(name=output_name, value=output_list)
72
+ return
73
+
69
74
  self._event_queue: Queue[Tuple[int, WorkflowEvent]] = Queue()
70
75
  self._concurrency_queue: Queue[Thread] = Queue()
71
76
  fulfilled_iterations: List[bool] = []
@@ -63,3 +63,25 @@ def test_map_node__use_parallelism():
63
63
  # THEN the node should have ran in parallel
64
64
  run_time = (end_ts - start_ts) / 10**9
65
65
  assert run_time < 0.2
66
+
67
+
68
+ def test_map_node__empty_list():
69
+ # GIVEN a map node that is configured to use the parent's inputs and state
70
+ @MapNode.wrap(items=[])
71
+ class TestNode(BaseNode):
72
+ item = MapNode.SubworkflowInputs.item
73
+
74
+ class Outputs(BaseOutputs):
75
+ value: int
76
+
77
+ def run(self) -> Outputs:
78
+ time.sleep(0.03)
79
+ return self.Outputs(value=self.item + 1)
80
+
81
+ # WHEN the node is run
82
+ node = TestNode()
83
+ outputs = list(node.run())
84
+
85
+ # THEN the node should return an empty output
86
+ fulfilled_output = outputs[-1]
87
+ assert fulfilled_output == BaseOutput(name="value", value=[])
@@ -1,8 +1,8 @@
1
1
  import pytest
2
2
  import os
3
- from typing import Any
3
+ from typing import Any, Union
4
4
 
5
- from vellum import CodeExecutorResponse, NumberVellumValue, StringInput
5
+ from vellum import CodeExecutorResponse, NumberVellumValue, StringInput, StringVellumValue
6
6
  from vellum.client.types.code_execution_package import CodeExecutionPackage
7
7
  from vellum.client.types.code_executor_secret_input import CodeExecutorSecretInput
8
8
  from vellum.client.types.function_call import FunctionCall
@@ -493,3 +493,69 @@ def main(word: str) -> dict:
493
493
  },
494
494
  "log": "",
495
495
  }
496
+
497
+
498
+ def test_run_node__array_input_with_vellum_values(vellum_client):
499
+ """Confirm that CodeExecutionNodes can handle arrays containing VellumValue objects."""
500
+
501
+ # GIVEN a node that subclasses CodeExecutionNode that processes an array of VellumValues
502
+ class State(BaseState):
503
+ pass
504
+
505
+ class ExampleCodeExecutionNode(CodeExecutionNode[State, str]):
506
+ code = """\
507
+ from typing import List, Dict
508
+ def main(arg1: List[Dict]) -> str:
509
+ return arg1[0]["value"] + " " + arg1[1]["value"]
510
+ """
511
+ runtime = "PYTHON_3_11_6"
512
+
513
+ code_inputs = {
514
+ "arg1": [
515
+ StringVellumValue(type="STRING", value="Hello", name="First"),
516
+ StringVellumValue(type="STRING", value="World", name="Second"),
517
+ ],
518
+ }
519
+
520
+ # WHEN we run the node
521
+ node = ExampleCodeExecutionNode(state=State())
522
+ outputs = node.run()
523
+
524
+ # THEN the node should successfully concatenate the values
525
+ assert outputs == {"result": "Hello World", "log": ""}
526
+
527
+ # AND we should not have invoked the Code via Vellum since it's running inline
528
+ vellum_client.execute_code.assert_not_called()
529
+
530
+
531
+ def test_run_node__union_output_type(vellum_client):
532
+ """Confirm that CodeExecutionNodes can handle Union output types."""
533
+
534
+ # GIVEN a node that subclasses CodeExecutionNode that returns a Union type
535
+ class State(BaseState):
536
+ pass
537
+
538
+ class ExampleCodeExecutionNode(CodeExecutionNode[State, Union[float, int]]):
539
+ code = """\
540
+ from typing import List, Dict
541
+ def main(arg1: List[Dict]) -> float:
542
+ return arg1[0]["value"] + arg1[1]["value"]
543
+ """
544
+ runtime = "PYTHON_3_11_6"
545
+
546
+ code_inputs = {
547
+ "arg1": [
548
+ NumberVellumValue(type="NUMBER", value=1.0, name="First"),
549
+ NumberVellumValue(type="NUMBER", value=2.0, name="Second"),
550
+ ],
551
+ }
552
+
553
+ # WHEN we run the node
554
+ node = ExampleCodeExecutionNode(state=State())
555
+ outputs = node.run()
556
+
557
+ # THEN the node should successfully sum the values
558
+ assert outputs == {"result": 3.0, "log": ""}
559
+
560
+ # AND we should not have invoked the Code via Vellum since it's running inline
561
+ vellum_client.execute_code.assert_not_called()
@@ -1,10 +1,11 @@
1
1
  import io
2
2
  import os
3
3
  import re
4
- from typing import Any, List, Tuple, Union
4
+ from typing import Any, List, Tuple, Union, get_args, get_origin
5
5
 
6
6
  from pydantic import BaseModel, ValidationError
7
7
 
8
+ from vellum import VellumValue
8
9
  from vellum.client.types.code_executor_input import CodeExecutorInput
9
10
  from vellum.workflows.errors.types import WorkflowErrorCode
10
11
  from vellum.workflows.exceptions import NodeException
@@ -74,8 +75,25 @@ def run_code_inline(
74
75
  ) -> Tuple[str, Any]:
75
76
  log_buffer = io.StringIO()
76
77
 
78
+ VELLUM_TYPES = get_args(VellumValue)
79
+
80
+ def wrap_value(value):
81
+ if isinstance(value, list):
82
+ return ListWrapper(
83
+ [
84
+ # Convert VellumValue to dict with its fields
85
+ (
86
+ item.model_dump()
87
+ if isinstance(item, VELLUM_TYPES)
88
+ else _clean_for_dict_wrapper(item) if isinstance(item, (dict, list)) else item
89
+ )
90
+ for item in value
91
+ ]
92
+ )
93
+ return _clean_for_dict_wrapper(value)
94
+
77
95
  exec_globals = {
78
- "__arg__inputs": {input_value.name: _clean_for_dict_wrapper(input_value.value) for input_value in input_values},
96
+ "__arg__inputs": {input_value.name: wrap_value(input_value.value) for input_value in input_values},
79
97
  "__arg__out": None,
80
98
  "print": lambda *args, **kwargs: log_buffer.write(f"{' '.join(args)}\n"),
81
99
  }
@@ -92,7 +110,14 @@ __arg__out = main({", ".join(run_args)})
92
110
  result = exec_globals["__arg__out"]
93
111
 
94
112
  if output_type != Any:
95
- if issubclass(output_type, BaseModel) and not isinstance(result, output_type):
113
+ if get_origin(output_type) is Union:
114
+ allowed_types = get_args(output_type)
115
+ if not isinstance(result, allowed_types):
116
+ raise NodeException(
117
+ code=WorkflowErrorCode.INVALID_OUTPUTS,
118
+ message=f"Expected output to be in types {allowed_types}, but received '{type(result).__name__}'",
119
+ )
120
+ elif issubclass(output_type, BaseModel) and not isinstance(result, output_type):
96
121
  try:
97
122
  result = output_type.model_validate(result)
98
123
  except ValidationError as e:
@@ -100,12 +125,10 @@ __arg__out = main({", ".join(run_args)})
100
125
  code=WorkflowErrorCode.INVALID_OUTPUTS,
101
126
  message=re.sub(r"\s+For further information visit [^\s]+", "", str(e)),
102
127
  ) from e
103
-
104
- if not isinstance(result, output_type):
128
+ elif not isinstance(result, output_type):
105
129
  raise NodeException(
106
130
  code=WorkflowErrorCode.INVALID_OUTPUTS,
107
- message=f"Expected an output of type '{output_type.__name__}',"
108
- f" but received '{result.__class__.__name__}'",
131
+ message=f"Expected an output of type '{output_type.__name__}', but received '{type(result).__name__}'",
109
132
  )
110
133
 
111
134
  return logs, result