vellum-ai 0.14.0__py3-none-any.whl → 0.14.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/prompts/blocks/compilation.py +23 -16
  3. vellum/prompts/blocks/tests/test_compilation.py +29 -0
  4. vellum/utils/templating/constants.py +6 -2
  5. vellum/utils/templating/custom_filters.py +22 -1
  6. vellum/utils/templating/render.py +5 -2
  7. vellum/utils/templating/tests/__init__.py +0 -0
  8. vellum/utils/templating/tests/test_custom_filters.py +19 -0
  9. vellum/workflows/errors/types.py +3 -0
  10. vellum/workflows/nodes/core/templating_node/node.py +3 -3
  11. vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +66 -1
  12. vellum/workflows/nodes/displayable/code_execution_node/node.py +5 -4
  13. vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +51 -0
  14. vellum/workflows/nodes/displayable/code_execution_node/utils.py +13 -11
  15. vellum/workflows/nodes/utils.py +9 -1
  16. vellum/workflows/state/base.py +21 -1
  17. {vellum_ai-0.14.0.dist-info → vellum_ai-0.14.2.dist-info}/METADATA +1 -1
  18. {vellum_ai-0.14.0.dist-info → vellum_ai-0.14.2.dist-info}/RECORD +51 -46
  19. vellum_cli/pull.py +3 -11
  20. vellum_ee/workflows/display/base.py +14 -0
  21. vellum_ee/workflows/display/nodes/base_node_display.py +11 -22
  22. vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_node.py +52 -0
  23. vellum_ee/workflows/display/tests/test_vellum_workflow_display.py +1 -0
  24. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +4 -1
  25. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +1 -0
  26. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +3 -0
  27. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +4 -0
  28. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +243 -0
  29. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +1 -0
  30. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +1 -0
  31. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +1 -1
  32. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -0
  33. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +1 -0
  34. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +1 -0
  35. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -0
  36. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +1 -1
  37. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +1 -0
  38. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -0
  39. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +1 -0
  40. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +1 -0
  41. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +1 -0
  42. vellum_ee/workflows/display/types.py +5 -1
  43. vellum_ee/workflows/display/utils/expressions.py +26 -0
  44. vellum_ee/workflows/display/utils/vellum.py +5 -0
  45. vellum_ee/workflows/display/vellum.py +15 -1
  46. vellum_ee/workflows/display/workflows/base_workflow_display.py +30 -1
  47. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +54 -6
  48. vellum_ee/workflows/tests/local_workflow/display/workflow.py +0 -2
  49. {vellum_ai-0.14.0.dist-info → vellum_ai-0.14.2.dist-info}/LICENSE +0 -0
  50. {vellum_ai-0.14.0.dist-info → vellum_ai-0.14.2.dist-info}/WHEEL +0 -0
  51. {vellum_ai-0.14.0.dist-info → vellum_ai-0.14.2.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
18
18
  headers: typing.Dict[str, str] = {
19
19
  "X-Fern-Language": "Python",
20
20
  "X-Fern-SDK-Name": "vellum-ai",
21
- "X-Fern-SDK-Version": "0.14.0",
21
+ "X-Fern-SDK-Version": "0.14.2",
22
22
  }
23
23
  headers["X_API_KEY"] = self.api_key
24
24
  return headers
@@ -1,5 +1,5 @@
1
1
  import json
2
- from typing import Optional, cast
2
+ from typing import Sequence, Union, cast
3
3
 
4
4
  from vellum import (
5
5
  ChatMessage,
@@ -14,6 +14,7 @@ from vellum.client.types.audio_vellum_value import AudioVellumValue
14
14
  from vellum.client.types.function_call import FunctionCall
15
15
  from vellum.client.types.function_call_vellum_value import FunctionCallVellumValue
16
16
  from vellum.client.types.image_vellum_value import ImageVellumValue
17
+ from vellum.client.types.number_input import NumberInput
17
18
  from vellum.client.types.vellum_audio import VellumAudio
18
19
  from vellum.client.types.vellum_image import VellumImage
19
20
  from vellum.prompts.blocks.exceptions import PromptCompilationError
@@ -22,15 +23,20 @@ from vellum.utils.templating.constants import DEFAULT_JINJA_CUSTOM_FILTERS
22
23
  from vellum.utils.templating.render import render_sandboxed_jinja_template
23
24
  from vellum.utils.typing import cast_not_optional
24
25
 
26
+ PromptInput = Union[PromptRequestInput, NumberInput]
27
+
25
28
 
26
29
  def compile_prompt_blocks(
27
30
  blocks: list[PromptBlock],
28
- inputs: list[PromptRequestInput],
31
+ inputs: Sequence[PromptInput],
29
32
  input_variables: list[VellumVariable],
30
33
  ) -> list[CompiledPromptBlock]:
31
34
  """Compiles a list of Prompt Blocks, performing all variable substitutions and Jinja templating needed."""
32
35
 
33
36
  sanitized_inputs = _sanitize_inputs(inputs)
37
+ inputs_by_name = {
38
+ (input_.name if isinstance(input_, NumberInput) else input_.key): input_ for input_ in sanitized_inputs
39
+ }
34
40
 
35
41
  compiled_blocks: list[CompiledPromptBlock] = []
36
42
  for block in blocks:
@@ -66,7 +72,7 @@ def compile_prompt_blocks(
66
72
 
67
73
  rendered_template = render_sandboxed_jinja_template(
68
74
  template=block.template,
69
- input_values={input_.key: input_.value for input_ in sanitized_inputs},
75
+ input_values={name: inp.value for name, inp in inputs_by_name.items()},
70
76
  jinja_custom_filters=DEFAULT_JINJA_CUSTOM_FILTERS,
71
77
  jinja_globals=DEFAULT_JINJA_CUSTOM_FILTERS,
72
78
  )
@@ -80,9 +86,7 @@ def compile_prompt_blocks(
80
86
  )
81
87
 
82
88
  elif block.block_type == "VARIABLE":
83
- compiled_input: Optional[PromptRequestInput] = next(
84
- (input_ for input_ in sanitized_inputs if input_.key == str(block.input_variable)), None
85
- )
89
+ compiled_input = inputs_by_name.get(block.input_variable)
86
90
  if compiled_input is None:
87
91
  raise PromptCompilationError(f"Input variable '{block.input_variable}' not found")
88
92
 
@@ -196,23 +200,26 @@ def _compile_chat_messages_as_prompt_blocks(chat_messages: list[ChatMessage]) ->
196
200
 
197
201
  def _compile_rich_text_block_as_value_block(
198
202
  block: RichTextPromptBlock,
199
- inputs: list[PromptRequestInput],
203
+ inputs: list[PromptInput],
200
204
  ) -> CompiledValuePromptBlock:
201
205
  value: str = ""
206
+ inputs_by_name = {(input_.name if isinstance(input_, NumberInput) else input_.key): input_ for input_ in inputs}
202
207
  for child_block in block.blocks:
203
208
  if child_block.block_type == "PLAIN_TEXT":
204
209
  value += child_block.text
205
210
  elif child_block.block_type == "VARIABLE":
206
- variable = next((input_ for input_ in inputs if input_.key == str(child_block.input_variable)), None)
207
- if variable is None:
211
+ input = inputs_by_name.get(child_block.input_variable)
212
+ if input is None:
208
213
  raise PromptCompilationError(f"Input variable '{child_block.input_variable}' not found")
209
- elif variable.type == "STRING":
210
- value += str(variable.value)
211
- elif variable.type == "JSON":
212
- value += json.dumps(variable.value, indent=4)
214
+ elif input.type == "STRING":
215
+ value += str(input.value)
216
+ elif input.type == "JSON":
217
+ value += json.dumps(input.value, indent=4)
218
+ elif input.type == "NUMBER":
219
+ value += str(input.value)
213
220
  else:
214
221
  raise PromptCompilationError(
215
- f"Input variable '{child_block.input_variable}' must be of type STRING or JSON"
222
+ f"Input variable '{child_block.input_variable}' has an invalid type: {input.type}"
216
223
  )
217
224
  else:
218
225
  raise PromptCompilationError(f"Invalid child block_type for RICH_TEXT: {child_block.block_type}")
@@ -220,8 +227,8 @@ def _compile_rich_text_block_as_value_block(
220
227
  return CompiledValuePromptBlock(content=StringVellumValue(value=value), cache_config=block.cache_config)
221
228
 
222
229
 
223
- def _sanitize_inputs(inputs: list[PromptRequestInput]) -> list[PromptRequestInput]:
224
- sanitized_inputs: list[PromptRequestInput] = []
230
+ def _sanitize_inputs(inputs: Sequence[PromptInput]) -> list[PromptInput]:
231
+ sanitized_inputs: list[PromptInput] = []
225
232
  for input_ in inputs:
226
233
  if input_.type == "CHAT_HISTORY" and input_.value is None:
227
234
  sanitized_inputs.append(input_.model_copy(update={"value": cast(list[ChatMessage], [])}))
@@ -10,6 +10,7 @@ from vellum import (
10
10
  VariablePromptBlock,
11
11
  VellumVariable,
12
12
  )
13
+ from vellum.client.types.number_input import NumberInput
13
14
  from vellum.prompts.blocks.compilation import compile_prompt_blocks
14
15
  from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock, CompiledValuePromptBlock
15
16
 
@@ -94,6 +95,33 @@ from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock, Compiled
94
95
  ),
95
96
  ],
96
97
  ),
98
+ (
99
+ [
100
+ ChatMessagePromptBlock(
101
+ chat_role="USER",
102
+ blocks=[
103
+ RichTextPromptBlock(
104
+ blocks=[
105
+ PlainTextPromptBlock(text="Count to "),
106
+ VariablePromptBlock(input_variable="count"),
107
+ ]
108
+ )
109
+ ],
110
+ )
111
+ ],
112
+ [
113
+ # TODO: We don't yet have PromptRequestNumberInput. We should either add it or migrate
114
+ # Prompts to using these more generic inputs.
115
+ NumberInput(name="count", value=10),
116
+ ],
117
+ [VellumVariable(id="901ec2d6-430c-4341-b963-ca689006f5cc", type="NUMBER", key="count")],
118
+ [
119
+ CompiledChatMessagePromptBlock(
120
+ role="USER",
121
+ blocks=[CompiledValuePromptBlock(content=StringVellumValue(value="Count to 10.0"))],
122
+ ),
123
+ ],
124
+ ),
97
125
  ],
98
126
  ids=[
99
127
  "empty",
@@ -102,6 +130,7 @@ from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock, Compiled
102
130
  "rich-text-no-variables",
103
131
  "rich-text-with-variables",
104
132
  "chat-message",
133
+ "number-input",
105
134
  ],
106
135
  )
107
136
  def test_compile_prompt_blocks__happy(blocks, inputs, input_variables, expected):
@@ -10,7 +10,7 @@ import pydash
10
10
  import pytz
11
11
  import yaml
12
12
 
13
- from vellum.utils.templating.custom_filters import is_valid_json_string
13
+ from vellum.utils.templating.custom_filters import is_valid_json_string, replace
14
14
 
15
15
  DEFAULT_JINJA_GLOBALS: Dict[str, Any] = {
16
16
  "datetime": datetime,
@@ -23,6 +23,10 @@ DEFAULT_JINJA_GLOBALS: Dict[str, Any] = {
23
23
  "re": re,
24
24
  "yaml": yaml,
25
25
  }
26
- DEFAULT_JINJA_CUSTOM_FILTERS: Dict[str, Callable[[Union[str, bytes]], bool]] = {
26
+
27
+ FilterFunc = Union[Callable[[Union[str, bytes]], bool], Callable[[Any, Any, Any], str]]
28
+
29
+ DEFAULT_JINJA_CUSTOM_FILTERS: Dict[str, FilterFunc] = {
27
30
  "is_valid_json_string": is_valid_json_string,
31
+ "replace": replace,
28
32
  }
@@ -1,5 +1,7 @@
1
1
  import json
2
- from typing import Union
2
+ from typing import Any, Union
3
+
4
+ from vellum.workflows.state.encoder import DefaultStateEncoder
3
5
 
4
6
 
5
7
  def is_valid_json_string(value: Union[str, bytes]) -> bool:
@@ -10,3 +12,22 @@ def is_valid_json_string(value: Union[str, bytes]) -> bool:
10
12
  except ValueError:
11
13
  return False
12
14
  return True
15
+
16
+
17
+ def replace(s: Any, old: Any, new: Any) -> str:
18
+ def encode_to_str(obj: Any) -> str:
19
+ """Encode an object for template rendering using DefaultStateEncoder."""
20
+ try:
21
+ if isinstance(obj, str):
22
+ return obj
23
+ return json.dumps(obj, cls=DefaultStateEncoder)
24
+ except TypeError:
25
+ return str(obj)
26
+
27
+ if old == "":
28
+ return encode_to_str(s)
29
+
30
+ s_str = encode_to_str(s)
31
+ old_str = encode_to_str(old)
32
+ new_str = encode_to_str(new)
33
+ return s_str.replace(old_str, new_str)
@@ -1,8 +1,9 @@
1
1
  import json
2
- from typing import Any, Callable, Dict, Optional, Union
2
+ from typing import Any, Dict, Optional
3
3
 
4
4
  from jinja2.sandbox import SandboxedEnvironment
5
5
 
6
+ from vellum.utils.templating.constants import FilterFunc
6
7
  from vellum.utils.templating.exceptions import JinjaTemplateError
7
8
  from vellum.workflows.state.encoder import DefaultStateEncoder
8
9
 
@@ -18,7 +19,7 @@ def render_sandboxed_jinja_template(
18
19
  *,
19
20
  template: str,
20
21
  input_values: Dict[str, Any],
21
- jinja_custom_filters: Optional[Dict[str, Callable[[Union[str, bytes]], bool]]] = None,
22
+ jinja_custom_filters: Optional[Dict[str, FilterFunc]] = None,
22
23
  jinja_globals: Optional[Dict[str, Any]] = None,
23
24
  ) -> str:
24
25
  """Render a Jinja template within a sandboxed environment."""
@@ -42,6 +43,8 @@ def render_sandboxed_jinja_template(
42
43
 
43
44
  rendered_template = jinja_template.render(input_values)
44
45
  except json.JSONDecodeError as e:
46
+ if not e.doc:
47
+ raise JinjaTemplateError("Unable to render jinja template:\n" "Cannot run json.loads() on empty input")
45
48
  if e.msg == "Invalid control character at":
46
49
  raise JinjaTemplateError(
47
50
  "Unable to render jinja template:\n"
File without changes
@@ -0,0 +1,19 @@
1
+ import pytest
2
+
3
+ from vellum.utils.templating.custom_filters import replace
4
+
5
+
6
+ @pytest.mark.parametrize(
7
+ ["input_str", "old", "new", "expected"],
8
+ [
9
+ ("foo", "foo", "bar", "bar"),
10
+ ({"message": "hello"}, "hello", "world", '{"message": "world"}'),
11
+ ("Value: 123", 123, 456, "Value: 456"),
12
+ (123, 2, 4, "143"),
13
+ ("", "", "", ""),
14
+ ("foo", "", "bar", "foo"),
15
+ ],
16
+ )
17
+ def test_replace(input_str, old, new, expected):
18
+ actual = replace(input_str, old, new)
19
+ assert actual == expected
@@ -13,6 +13,7 @@ class WorkflowErrorCode(Enum):
13
13
  INVALID_INPUTS = "INVALID_INPUTS"
14
14
  INVALID_OUTPUTS = "INVALID_OUTPUTS"
15
15
  INVALID_STATE = "INVALID_STATE"
16
+ INVALID_CODE = "INVALID_CODE"
16
17
  INVALID_TEMPLATE = "INVALID_TEMPLATE"
17
18
  INTERNAL_ERROR = "INTERNAL_ERROR"
18
19
  NODE_EXECUTION = "NODE_EXECUTION"
@@ -54,6 +55,7 @@ _WORKFLOW_EVENT_ERROR_CODE_TO_WORKFLOW_ERROR_CODE: Dict[WorkflowExecutionEventEr
54
55
  "INTERNAL_SERVER_ERROR": WorkflowErrorCode.INTERNAL_ERROR,
55
56
  "NODE_EXECUTION": WorkflowErrorCode.NODE_EXECUTION,
56
57
  "LLM_PROVIDER": WorkflowErrorCode.PROVIDER_ERROR,
58
+ "INVALID_CODE": WorkflowErrorCode.INVALID_CODE,
57
59
  "INVALID_TEMPLATE": WorkflowErrorCode.INVALID_TEMPLATE,
58
60
  "USER_DEFINED_ERROR": WorkflowErrorCode.USER_DEFINED_ERROR,
59
61
  }
@@ -71,6 +73,7 @@ _WORKFLOW_ERROR_CODE_TO_VELLUM_ERROR_CODE: Dict[WorkflowErrorCode, VellumErrorCo
71
73
  WorkflowErrorCode.INVALID_INPUTS: "INVALID_INPUTS",
72
74
  WorkflowErrorCode.INVALID_OUTPUTS: "INVALID_REQUEST",
73
75
  WorkflowErrorCode.INVALID_STATE: "INVALID_REQUEST",
76
+ WorkflowErrorCode.INVALID_CODE: "INVALID_CODE",
74
77
  WorkflowErrorCode.INVALID_TEMPLATE: "INVALID_INPUTS",
75
78
  WorkflowErrorCode.INTERNAL_ERROR: "INTERNAL_SERVER_ERROR",
76
79
  WorkflowErrorCode.NODE_EXECUTION: "USER_DEFINED_ERROR",
@@ -1,6 +1,6 @@
1
- from typing import Any, Callable, ClassVar, Dict, Generic, Mapping, Tuple, Type, TypeVar, Union, get_args
1
+ from typing import Any, ClassVar, Dict, Generic, Mapping, Tuple, Type, TypeVar, get_args
2
2
 
3
- from vellum.utils.templating.constants import DEFAULT_JINJA_CUSTOM_FILTERS, DEFAULT_JINJA_GLOBALS
3
+ from vellum.utils.templating.constants import DEFAULT_JINJA_CUSTOM_FILTERS, DEFAULT_JINJA_GLOBALS, FilterFunc
4
4
  from vellum.utils.templating.exceptions import JinjaTemplateError
5
5
  from vellum.utils.templating.render import render_sandboxed_jinja_template
6
6
  from vellum.workflows.errors import WorkflowErrorCode
@@ -54,7 +54,7 @@ class TemplatingNode(BaseNode[StateType], Generic[StateType, _OutputType], metac
54
54
  inputs: ClassVar[EntityInputsInterface]
55
55
 
56
56
  jinja_globals: Dict[str, Any] = DEFAULT_JINJA_GLOBALS
57
- jinja_custom_filters: Mapping[str, Callable[[Union[str, bytes]], bool]] = DEFAULT_JINJA_CUSTOM_FILTERS
57
+ jinja_custom_filters: Mapping[str, FilterFunc] = DEFAULT_JINJA_CUSTOM_FILTERS
58
58
 
59
59
  class Outputs(BaseNode.Outputs):
60
60
  """
@@ -1,8 +1,11 @@
1
+ import pytest
1
2
  import json
2
- from typing import List
3
+ from typing import List, Union
3
4
 
4
5
  from vellum.client.types.chat_message import ChatMessage
5
6
  from vellum.client.types.function_call import FunctionCall
7
+ from vellum.client.types.function_call_vellum_value import FunctionCallVellumValue
8
+ from vellum.workflows.exceptions import NodeException
6
9
  from vellum.workflows.nodes.bases.base import BaseNode
7
10
  from vellum.workflows.nodes.core.templating_node.node import TemplatingNode
8
11
  from vellum.workflows.state import BaseState
@@ -155,3 +158,65 @@ def test_templating_node__function_call_output():
155
158
 
156
159
  # THEN the output is the expected function call
157
160
  assert outputs.result == FunctionCall(name="test", arguments={"key": "value"})
161
+
162
+
163
+ def test_templating_node__blank_json_input():
164
+ """Test that templating node properly handles blank JSON input."""
165
+
166
+ # GIVEN a templating node that tries to parse blank JSON
167
+ class BlankJsonTemplateNode(TemplatingNode[BaseState, Json]):
168
+ template = "{{ json.loads(data) }}"
169
+ inputs = {
170
+ "data": "", # Blank input
171
+ }
172
+
173
+ # WHEN the node is run
174
+ node = BlankJsonTemplateNode()
175
+
176
+ # THEN it should raise an appropriate error
177
+ with pytest.raises(NodeException) as exc_info:
178
+ node.run()
179
+
180
+ assert "Unable to render jinja template:\nCannot run json.loads() on empty input" in str(exc_info.value)
181
+
182
+
183
+ def test_templating_node__union_float_int_output():
184
+ # GIVEN a templating node that outputs either a float or an int
185
+ class UnionTemplateNode(TemplatingNode[BaseState, Union[float, int]]):
186
+ template = """{{ obj[\"score\"] | float }}"""
187
+ inputs = {
188
+ "obj": {"score": 42.5},
189
+ }
190
+
191
+ # WHEN the node is run
192
+ node = UnionTemplateNode()
193
+ outputs = node.run()
194
+
195
+ # THEN it should correctly parse as a float
196
+ assert outputs.result == 42.5
197
+
198
+
199
+ def test_templating_node__replace_filter():
200
+ # GIVEN a templating node that outputs a complex object
201
+ class ReplaceFilterTemplateNode(TemplatingNode[BaseState, Json]):
202
+ template = """{{- prompt_outputs | selectattr(\'type\', \'equalto\', \'FUNCTION_CALL\') \
203
+ | list | replace(\"\\n\",\",\") -}}"""
204
+ inputs = {
205
+ "prompt_outputs": [FunctionCallVellumValue(value=FunctionCall(name="test", arguments={"key": "value"}))]
206
+ }
207
+
208
+ # WHEN the node is run
209
+ node = ReplaceFilterTemplateNode()
210
+ outputs = node.run()
211
+
212
+ # THEN the output is the expected JSON
213
+ assert outputs.result == [
214
+ {
215
+ "type": "FUNCTION_CALL",
216
+ "value": {
217
+ "name": "test",
218
+ "arguments": {"key": "value"},
219
+ "id": None,
220
+ },
221
+ }
222
+ ]
@@ -93,13 +93,14 @@ class CodeExecutionNode(BaseNode[StateType], Generic[StateType, _OutputType], me
93
93
  log: str
94
94
 
95
95
  def run(self) -> Outputs:
96
- input_values = self._compile_code_inputs()
97
96
  output_type = self.__class__.get_output_type()
98
97
  code = self._resolve_code()
99
98
  if not self.packages and self.runtime == "PYTHON_3_11_6":
100
- logs, result = run_code_inline(code, input_values, output_type)
99
+ logs, result = run_code_inline(code, self.code_inputs, output_type)
101
100
  return self.Outputs(result=result, log=logs)
101
+
102
102
  else:
103
+ input_values = self._compile_code_inputs()
103
104
  expected_output_type = primitive_type_to_vellum_variable_type(output_type)
104
105
 
105
106
  code_execution_result = self._context.vellum_client.execute_code(
@@ -131,7 +132,7 @@ class CodeExecutionNode(BaseNode[StateType], Generic[StateType, _OutputType], me
131
132
  compiled_inputs.append(
132
133
  StringInput(
133
134
  name=input_name,
134
- value=str(input_value),
135
+ value=input_value,
135
136
  )
136
137
  )
137
138
  elif isinstance(input_value, VellumSecret):
@@ -193,7 +194,7 @@ class CodeExecutionNode(BaseNode[StateType], Generic[StateType, _OutputType], me
193
194
  )
194
195
  else:
195
196
  raise NodeException(
196
- message=f"Unrecognized input type for input '{input_name}'",
197
+ message=f"Unrecognized input type for input '{input_name}': {input_value.__class__.__name__}",
197
198
  code=WorkflowErrorCode.INVALID_INPUTS,
198
199
  )
199
200
 
@@ -7,6 +7,7 @@ from vellum.client.types.code_execution_package import CodeExecutionPackage
7
7
  from vellum.client.types.code_executor_secret_input import CodeExecutorSecretInput
8
8
  from vellum.client.types.function_call import FunctionCall
9
9
  from vellum.client.types.number_input import NumberInput
10
+ from vellum.workflows.errors import WorkflowErrorCode
10
11
  from vellum.workflows.exceptions import NodeException
11
12
  from vellum.workflows.inputs.base import BaseInputs
12
13
  from vellum.workflows.nodes.displayable.code_execution_node import CodeExecutionNode
@@ -559,3 +560,53 @@ def main(arg1: List[Dict]) -> float:
559
560
 
560
561
  # AND we should not have invoked the Code via Vellum since it's running inline
561
562
  vellum_client.execute_code.assert_not_called()
563
+
564
+
565
+ def test_run_node__code_execution_error():
566
+ # GIVEN a node that will raise an error during execution
567
+ class State(BaseState):
568
+ pass
569
+
570
+ class ExampleCodeExecutionNode(CodeExecutionNode[State, int]):
571
+ code = """\
572
+ def main(arg1: int, arg2: int) -> int:
573
+ return arg1 + arg2 + arg3
574
+ """
575
+ runtime = "PYTHON_3_11_6"
576
+ code_inputs = {
577
+ "arg1": 1,
578
+ "arg2": 2,
579
+ }
580
+
581
+ # WHEN we run the node
582
+ node = ExampleCodeExecutionNode(state=State())
583
+
584
+ # THEN it should raise a NodeException with the execution error
585
+ with pytest.raises(NodeException) as exc_info:
586
+ node.run()
587
+
588
+ # AND the error should contain the execution error details
589
+ assert "name 'arg3' is not defined" in str(exc_info.value)
590
+ assert exc_info.value.code == WorkflowErrorCode.INVALID_CODE
591
+
592
+
593
+ def test_run_node__array_of_bools_input():
594
+ # GIVEN a node that will raise an error during execution
595
+ class ExampleCodeExecutionNode(CodeExecutionNode[BaseState, int]):
596
+ code = """\
597
+ def main(arg1: list[bool]) -> int:
598
+ return len(arg1)
599
+ """
600
+ runtime = "PYTHON_3_11_6"
601
+ code_inputs = {
602
+ "arg1": [True, False, True],
603
+ }
604
+
605
+ # WHEN we run the node
606
+ node = ExampleCodeExecutionNode()
607
+
608
+ # THEN it should raise a NodeException with the execution error
609
+ outputs = node.run()
610
+
611
+ # AND the error should contain the execution error details
612
+ assert outputs == {"result": 3, "log": ""}
@@ -1,14 +1,13 @@
1
1
  import io
2
2
  import os
3
3
  import re
4
- from typing import Any, List, Tuple, Union, get_args, get_origin
4
+ from typing import Any, Tuple, Union, get_args, get_origin
5
5
 
6
6
  from pydantic import BaseModel, ValidationError
7
7
 
8
- from vellum import VellumValue
9
- from vellum.client.types.code_executor_input import CodeExecutorInput
10
8
  from vellum.workflows.errors.types import WorkflowErrorCode
11
9
  from vellum.workflows.exceptions import NodeException
10
+ from vellum.workflows.types.core import EntityInputsInterface
12
11
 
13
12
 
14
13
  def read_file_from_path(node_filepath: str, script_filepath: str) -> Union[str, None]:
@@ -70,13 +69,11 @@ def _clean_for_dict_wrapper(obj):
70
69
 
71
70
  def run_code_inline(
72
71
  code: str,
73
- input_values: List[CodeExecutorInput],
72
+ inputs: EntityInputsInterface,
74
73
  output_type: Any,
75
74
  ) -> Tuple[str, Any]:
76
75
  log_buffer = io.StringIO()
77
76
 
78
- VELLUM_TYPES = get_args(VellumValue)
79
-
80
77
  def wrap_value(value):
81
78
  if isinstance(value, list):
82
79
  return ListWrapper(
@@ -84,7 +81,7 @@ def run_code_inline(
84
81
  # Convert VellumValue to dict with its fields
85
82
  (
86
83
  item.model_dump()
87
- if isinstance(item, VELLUM_TYPES)
84
+ if isinstance(item, BaseModel)
88
85
  else _clean_for_dict_wrapper(item) if isinstance(item, (dict, list)) else item
89
86
  )
90
87
  for item in value
@@ -93,18 +90,23 @@ def run_code_inline(
93
90
  return _clean_for_dict_wrapper(value)
94
91
 
95
92
  exec_globals = {
96
- "__arg__inputs": {input_value.name: wrap_value(input_value.value) for input_value in input_values},
93
+ "__arg__inputs": {name: wrap_value(value) for name, value in inputs.items()},
97
94
  "__arg__out": None,
98
95
  "print": lambda *args, **kwargs: log_buffer.write(f"{' '.join(args)}\n"),
99
96
  }
100
- run_args = [f"{input_value.name}=__arg__inputs['{input_value.name}']" for input_value in input_values]
97
+ run_args = [f"{name}=__arg__inputs['{name}']" for name in inputs.keys()]
101
98
  execution_code = f"""\
102
99
  {code}
103
100
 
104
101
  __arg__out = main({", ".join(run_args)})
105
102
  """
106
-
107
- exec(execution_code, exec_globals)
103
+ try:
104
+ exec(execution_code, exec_globals)
105
+ except Exception as e:
106
+ raise NodeException(
107
+ code=WorkflowErrorCode.INVALID_CODE,
108
+ message=str(e),
109
+ )
108
110
 
109
111
  logs = log_buffer.getvalue()
110
112
  result = exec_globals["__arg__out"]
@@ -2,7 +2,7 @@ from functools import cache
2
2
  import json
3
3
  import sys
4
4
  from types import ModuleType
5
- from typing import Any, Callable, Optional, Type, TypeVar, get_args, get_origin
5
+ from typing import Any, Callable, Optional, Type, TypeVar, Union, get_args, get_origin
6
6
 
7
7
  from pydantic import BaseModel
8
8
 
@@ -113,6 +113,14 @@ def parse_type_from_str(result_as_str: str, output_type: Any) -> Any:
113
113
  except json.JSONDecodeError:
114
114
  raise ValueError("Invalid JSON format for result_as_str")
115
115
 
116
+ if get_origin(output_type) is Union:
117
+ for inner_type in get_args(output_type):
118
+ try:
119
+ return parse_type_from_str(result_as_str, inner_type)
120
+ except ValueError:
121
+ continue
122
+ raise ValueError(f"Could not parse with any of the Union types: {output_type}")
123
+
116
124
  if issubclass(output_type, BaseModel):
117
125
  try:
118
126
  data = json.loads(result_as_str)
@@ -18,7 +18,13 @@ from vellum.workflows.inputs.base import BaseInputs
18
18
  from vellum.workflows.references import ExternalInputReference, OutputReference, StateValueReference
19
19
  from vellum.workflows.types.generics import StateType
20
20
  from vellum.workflows.types.stack import Stack
21
- from vellum.workflows.types.utils import datetime_now, deepcopy_with_exclusions, get_class_by_qualname, infer_types
21
+ from vellum.workflows.types.utils import (
22
+ datetime_now,
23
+ deepcopy_with_exclusions,
24
+ get_class_attr_names,
25
+ get_class_by_qualname,
26
+ infer_types,
27
+ )
22
28
 
23
29
  if TYPE_CHECKING:
24
30
  from vellum.workflows.nodes.bases import BaseNode
@@ -47,6 +53,20 @@ class _BaseStateMeta(type):
47
53
 
48
54
  return super().__getattribute__(name)
49
55
 
56
+ def __iter__(cls) -> Iterator[StateValueReference]:
57
+ # We iterate through the inheritance hierarchy to find all the StateValueReference attached to this
58
+ # Inputs class. __mro__ is the method resolution order, which is the order in which base classes are resolved.
59
+ for resolved_cls in cls.__mro__:
60
+ attr_names = get_class_attr_names(resolved_cls)
61
+ for attr_name in attr_names:
62
+ if attr_name == "meta":
63
+ continue
64
+ attr_value = getattr(resolved_cls, attr_name)
65
+ if not isinstance(attr_value, (StateValueReference)):
66
+ continue
67
+
68
+ yield attr_value
69
+
50
70
 
51
71
  class _SnapshottableDict(dict, _Snapshottable):
52
72
  def __setitem__(self, key: Any, value: Any) -> None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vellum-ai
3
- Version: 0.14.0
3
+ Version: 0.14.2
4
4
  Summary:
5
5
  License: MIT
6
6
  Requires-Python: >=3.9,<4.0