vellum-ai 0.14.0__py3-none-any.whl → 0.14.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/prompts/blocks/compilation.py +23 -16
- vellum/prompts/blocks/tests/test_compilation.py +29 -0
- vellum/utils/templating/constants.py +6 -2
- vellum/utils/templating/custom_filters.py +22 -1
- vellum/utils/templating/render.py +5 -2
- vellum/utils/templating/tests/__init__.py +0 -0
- vellum/utils/templating/tests/test_custom_filters.py +19 -0
- vellum/workflows/errors/types.py +3 -0
- vellum/workflows/nodes/core/templating_node/node.py +3 -3
- vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +66 -1
- vellum/workflows/nodes/displayable/code_execution_node/node.py +5 -4
- vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +51 -0
- vellum/workflows/nodes/displayable/code_execution_node/utils.py +13 -11
- vellum/workflows/nodes/utils.py +9 -1
- vellum/workflows/state/base.py +21 -1
- {vellum_ai-0.14.0.dist-info → vellum_ai-0.14.2.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.0.dist-info → vellum_ai-0.14.2.dist-info}/RECORD +51 -46
- vellum_cli/pull.py +3 -11
- vellum_ee/workflows/display/base.py +14 -0
- vellum_ee/workflows/display/nodes/base_node_display.py +11 -22
- vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_node.py +52 -0
- vellum_ee/workflows/display/tests/test_vellum_workflow_display.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +4 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +3 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +4 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +243 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +1 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +1 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +1 -0
- vellum_ee/workflows/display/types.py +5 -1
- vellum_ee/workflows/display/utils/expressions.py +26 -0
- vellum_ee/workflows/display/utils/vellum.py +5 -0
- vellum_ee/workflows/display/vellum.py +15 -1
- vellum_ee/workflows/display/workflows/base_workflow_display.py +30 -1
- vellum_ee/workflows/display/workflows/vellum_workflow_display.py +54 -6
- vellum_ee/workflows/tests/local_workflow/display/workflow.py +0 -2
- {vellum_ai-0.14.0.dist-info → vellum_ai-0.14.2.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.0.dist-info → vellum_ai-0.14.2.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.0.dist-info → vellum_ai-0.14.2.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
|
|
18
18
|
headers: typing.Dict[str, str] = {
|
19
19
|
"X-Fern-Language": "Python",
|
20
20
|
"X-Fern-SDK-Name": "vellum-ai",
|
21
|
-
"X-Fern-SDK-Version": "0.14.
|
21
|
+
"X-Fern-SDK-Version": "0.14.2",
|
22
22
|
}
|
23
23
|
headers["X_API_KEY"] = self.api_key
|
24
24
|
return headers
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import json
|
2
|
-
from typing import
|
2
|
+
from typing import Sequence, Union, cast
|
3
3
|
|
4
4
|
from vellum import (
|
5
5
|
ChatMessage,
|
@@ -14,6 +14,7 @@ from vellum.client.types.audio_vellum_value import AudioVellumValue
|
|
14
14
|
from vellum.client.types.function_call import FunctionCall
|
15
15
|
from vellum.client.types.function_call_vellum_value import FunctionCallVellumValue
|
16
16
|
from vellum.client.types.image_vellum_value import ImageVellumValue
|
17
|
+
from vellum.client.types.number_input import NumberInput
|
17
18
|
from vellum.client.types.vellum_audio import VellumAudio
|
18
19
|
from vellum.client.types.vellum_image import VellumImage
|
19
20
|
from vellum.prompts.blocks.exceptions import PromptCompilationError
|
@@ -22,15 +23,20 @@ from vellum.utils.templating.constants import DEFAULT_JINJA_CUSTOM_FILTERS
|
|
22
23
|
from vellum.utils.templating.render import render_sandboxed_jinja_template
|
23
24
|
from vellum.utils.typing import cast_not_optional
|
24
25
|
|
26
|
+
PromptInput = Union[PromptRequestInput, NumberInput]
|
27
|
+
|
25
28
|
|
26
29
|
def compile_prompt_blocks(
|
27
30
|
blocks: list[PromptBlock],
|
28
|
-
inputs:
|
31
|
+
inputs: Sequence[PromptInput],
|
29
32
|
input_variables: list[VellumVariable],
|
30
33
|
) -> list[CompiledPromptBlock]:
|
31
34
|
"""Compiles a list of Prompt Blocks, performing all variable substitutions and Jinja templating needed."""
|
32
35
|
|
33
36
|
sanitized_inputs = _sanitize_inputs(inputs)
|
37
|
+
inputs_by_name = {
|
38
|
+
(input_.name if isinstance(input_, NumberInput) else input_.key): input_ for input_ in sanitized_inputs
|
39
|
+
}
|
34
40
|
|
35
41
|
compiled_blocks: list[CompiledPromptBlock] = []
|
36
42
|
for block in blocks:
|
@@ -66,7 +72,7 @@ def compile_prompt_blocks(
|
|
66
72
|
|
67
73
|
rendered_template = render_sandboxed_jinja_template(
|
68
74
|
template=block.template,
|
69
|
-
input_values={
|
75
|
+
input_values={name: inp.value for name, inp in inputs_by_name.items()},
|
70
76
|
jinja_custom_filters=DEFAULT_JINJA_CUSTOM_FILTERS,
|
71
77
|
jinja_globals=DEFAULT_JINJA_CUSTOM_FILTERS,
|
72
78
|
)
|
@@ -80,9 +86,7 @@ def compile_prompt_blocks(
|
|
80
86
|
)
|
81
87
|
|
82
88
|
elif block.block_type == "VARIABLE":
|
83
|
-
compiled_input
|
84
|
-
(input_ for input_ in sanitized_inputs if input_.key == str(block.input_variable)), None
|
85
|
-
)
|
89
|
+
compiled_input = inputs_by_name.get(block.input_variable)
|
86
90
|
if compiled_input is None:
|
87
91
|
raise PromptCompilationError(f"Input variable '{block.input_variable}' not found")
|
88
92
|
|
@@ -196,23 +200,26 @@ def _compile_chat_messages_as_prompt_blocks(chat_messages: list[ChatMessage]) ->
|
|
196
200
|
|
197
201
|
def _compile_rich_text_block_as_value_block(
|
198
202
|
block: RichTextPromptBlock,
|
199
|
-
inputs: list[
|
203
|
+
inputs: list[PromptInput],
|
200
204
|
) -> CompiledValuePromptBlock:
|
201
205
|
value: str = ""
|
206
|
+
inputs_by_name = {(input_.name if isinstance(input_, NumberInput) else input_.key): input_ for input_ in inputs}
|
202
207
|
for child_block in block.blocks:
|
203
208
|
if child_block.block_type == "PLAIN_TEXT":
|
204
209
|
value += child_block.text
|
205
210
|
elif child_block.block_type == "VARIABLE":
|
206
|
-
|
207
|
-
if
|
211
|
+
input = inputs_by_name.get(child_block.input_variable)
|
212
|
+
if input is None:
|
208
213
|
raise PromptCompilationError(f"Input variable '{child_block.input_variable}' not found")
|
209
|
-
elif
|
210
|
-
value += str(
|
211
|
-
elif
|
212
|
-
value += json.dumps(
|
214
|
+
elif input.type == "STRING":
|
215
|
+
value += str(input.value)
|
216
|
+
elif input.type == "JSON":
|
217
|
+
value += json.dumps(input.value, indent=4)
|
218
|
+
elif input.type == "NUMBER":
|
219
|
+
value += str(input.value)
|
213
220
|
else:
|
214
221
|
raise PromptCompilationError(
|
215
|
-
f"Input variable '{child_block.input_variable}'
|
222
|
+
f"Input variable '{child_block.input_variable}' has an invalid type: {input.type}"
|
216
223
|
)
|
217
224
|
else:
|
218
225
|
raise PromptCompilationError(f"Invalid child block_type for RICH_TEXT: {child_block.block_type}")
|
@@ -220,8 +227,8 @@ def _compile_rich_text_block_as_value_block(
|
|
220
227
|
return CompiledValuePromptBlock(content=StringVellumValue(value=value), cache_config=block.cache_config)
|
221
228
|
|
222
229
|
|
223
|
-
def _sanitize_inputs(inputs:
|
224
|
-
sanitized_inputs: list[
|
230
|
+
def _sanitize_inputs(inputs: Sequence[PromptInput]) -> list[PromptInput]:
|
231
|
+
sanitized_inputs: list[PromptInput] = []
|
225
232
|
for input_ in inputs:
|
226
233
|
if input_.type == "CHAT_HISTORY" and input_.value is None:
|
227
234
|
sanitized_inputs.append(input_.model_copy(update={"value": cast(list[ChatMessage], [])}))
|
@@ -10,6 +10,7 @@ from vellum import (
|
|
10
10
|
VariablePromptBlock,
|
11
11
|
VellumVariable,
|
12
12
|
)
|
13
|
+
from vellum.client.types.number_input import NumberInput
|
13
14
|
from vellum.prompts.blocks.compilation import compile_prompt_blocks
|
14
15
|
from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock, CompiledValuePromptBlock
|
15
16
|
|
@@ -94,6 +95,33 @@ from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock, Compiled
|
|
94
95
|
),
|
95
96
|
],
|
96
97
|
),
|
98
|
+
(
|
99
|
+
[
|
100
|
+
ChatMessagePromptBlock(
|
101
|
+
chat_role="USER",
|
102
|
+
blocks=[
|
103
|
+
RichTextPromptBlock(
|
104
|
+
blocks=[
|
105
|
+
PlainTextPromptBlock(text="Count to "),
|
106
|
+
VariablePromptBlock(input_variable="count"),
|
107
|
+
]
|
108
|
+
)
|
109
|
+
],
|
110
|
+
)
|
111
|
+
],
|
112
|
+
[
|
113
|
+
# TODO: We don't yet have PromptRequestNumberInput. We should either add it or migrate
|
114
|
+
# Prompts to using these more generic inputs.
|
115
|
+
NumberInput(name="count", value=10),
|
116
|
+
],
|
117
|
+
[VellumVariable(id="901ec2d6-430c-4341-b963-ca689006f5cc", type="NUMBER", key="count")],
|
118
|
+
[
|
119
|
+
CompiledChatMessagePromptBlock(
|
120
|
+
role="USER",
|
121
|
+
blocks=[CompiledValuePromptBlock(content=StringVellumValue(value="Count to 10.0"))],
|
122
|
+
),
|
123
|
+
],
|
124
|
+
),
|
97
125
|
],
|
98
126
|
ids=[
|
99
127
|
"empty",
|
@@ -102,6 +130,7 @@ from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock, Compiled
|
|
102
130
|
"rich-text-no-variables",
|
103
131
|
"rich-text-with-variables",
|
104
132
|
"chat-message",
|
133
|
+
"number-input",
|
105
134
|
],
|
106
135
|
)
|
107
136
|
def test_compile_prompt_blocks__happy(blocks, inputs, input_variables, expected):
|
@@ -10,7 +10,7 @@ import pydash
|
|
10
10
|
import pytz
|
11
11
|
import yaml
|
12
12
|
|
13
|
-
from vellum.utils.templating.custom_filters import is_valid_json_string
|
13
|
+
from vellum.utils.templating.custom_filters import is_valid_json_string, replace
|
14
14
|
|
15
15
|
DEFAULT_JINJA_GLOBALS: Dict[str, Any] = {
|
16
16
|
"datetime": datetime,
|
@@ -23,6 +23,10 @@ DEFAULT_JINJA_GLOBALS: Dict[str, Any] = {
|
|
23
23
|
"re": re,
|
24
24
|
"yaml": yaml,
|
25
25
|
}
|
26
|
-
|
26
|
+
|
27
|
+
FilterFunc = Union[Callable[[Union[str, bytes]], bool], Callable[[Any, Any, Any], str]]
|
28
|
+
|
29
|
+
DEFAULT_JINJA_CUSTOM_FILTERS: Dict[str, FilterFunc] = {
|
27
30
|
"is_valid_json_string": is_valid_json_string,
|
31
|
+
"replace": replace,
|
28
32
|
}
|
@@ -1,5 +1,7 @@
|
|
1
1
|
import json
|
2
|
-
from typing import Union
|
2
|
+
from typing import Any, Union
|
3
|
+
|
4
|
+
from vellum.workflows.state.encoder import DefaultStateEncoder
|
3
5
|
|
4
6
|
|
5
7
|
def is_valid_json_string(value: Union[str, bytes]) -> bool:
|
@@ -10,3 +12,22 @@ def is_valid_json_string(value: Union[str, bytes]) -> bool:
|
|
10
12
|
except ValueError:
|
11
13
|
return False
|
12
14
|
return True
|
15
|
+
|
16
|
+
|
17
|
+
def replace(s: Any, old: Any, new: Any) -> str:
|
18
|
+
def encode_to_str(obj: Any) -> str:
|
19
|
+
"""Encode an object for template rendering using DefaultStateEncoder."""
|
20
|
+
try:
|
21
|
+
if isinstance(obj, str):
|
22
|
+
return obj
|
23
|
+
return json.dumps(obj, cls=DefaultStateEncoder)
|
24
|
+
except TypeError:
|
25
|
+
return str(obj)
|
26
|
+
|
27
|
+
if old == "":
|
28
|
+
return encode_to_str(s)
|
29
|
+
|
30
|
+
s_str = encode_to_str(s)
|
31
|
+
old_str = encode_to_str(old)
|
32
|
+
new_str = encode_to_str(new)
|
33
|
+
return s_str.replace(old_str, new_str)
|
@@ -1,8 +1,9 @@
|
|
1
1
|
import json
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Dict, Optional
|
3
3
|
|
4
4
|
from jinja2.sandbox import SandboxedEnvironment
|
5
5
|
|
6
|
+
from vellum.utils.templating.constants import FilterFunc
|
6
7
|
from vellum.utils.templating.exceptions import JinjaTemplateError
|
7
8
|
from vellum.workflows.state.encoder import DefaultStateEncoder
|
8
9
|
|
@@ -18,7 +19,7 @@ def render_sandboxed_jinja_template(
|
|
18
19
|
*,
|
19
20
|
template: str,
|
20
21
|
input_values: Dict[str, Any],
|
21
|
-
jinja_custom_filters: Optional[Dict[str,
|
22
|
+
jinja_custom_filters: Optional[Dict[str, FilterFunc]] = None,
|
22
23
|
jinja_globals: Optional[Dict[str, Any]] = None,
|
23
24
|
) -> str:
|
24
25
|
"""Render a Jinja template within a sandboxed environment."""
|
@@ -42,6 +43,8 @@ def render_sandboxed_jinja_template(
|
|
42
43
|
|
43
44
|
rendered_template = jinja_template.render(input_values)
|
44
45
|
except json.JSONDecodeError as e:
|
46
|
+
if not e.doc:
|
47
|
+
raise JinjaTemplateError("Unable to render jinja template:\n" "Cannot run json.loads() on empty input")
|
45
48
|
if e.msg == "Invalid control character at":
|
46
49
|
raise JinjaTemplateError(
|
47
50
|
"Unable to render jinja template:\n"
|
File without changes
|
@@ -0,0 +1,19 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
from vellum.utils.templating.custom_filters import replace
|
4
|
+
|
5
|
+
|
6
|
+
@pytest.mark.parametrize(
|
7
|
+
["input_str", "old", "new", "expected"],
|
8
|
+
[
|
9
|
+
("foo", "foo", "bar", "bar"),
|
10
|
+
({"message": "hello"}, "hello", "world", '{"message": "world"}'),
|
11
|
+
("Value: 123", 123, 456, "Value: 456"),
|
12
|
+
(123, 2, 4, "143"),
|
13
|
+
("", "", "", ""),
|
14
|
+
("foo", "", "bar", "foo"),
|
15
|
+
],
|
16
|
+
)
|
17
|
+
def test_replace(input_str, old, new, expected):
|
18
|
+
actual = replace(input_str, old, new)
|
19
|
+
assert actual == expected
|
vellum/workflows/errors/types.py
CHANGED
@@ -13,6 +13,7 @@ class WorkflowErrorCode(Enum):
|
|
13
13
|
INVALID_INPUTS = "INVALID_INPUTS"
|
14
14
|
INVALID_OUTPUTS = "INVALID_OUTPUTS"
|
15
15
|
INVALID_STATE = "INVALID_STATE"
|
16
|
+
INVALID_CODE = "INVALID_CODE"
|
16
17
|
INVALID_TEMPLATE = "INVALID_TEMPLATE"
|
17
18
|
INTERNAL_ERROR = "INTERNAL_ERROR"
|
18
19
|
NODE_EXECUTION = "NODE_EXECUTION"
|
@@ -54,6 +55,7 @@ _WORKFLOW_EVENT_ERROR_CODE_TO_WORKFLOW_ERROR_CODE: Dict[WorkflowExecutionEventEr
|
|
54
55
|
"INTERNAL_SERVER_ERROR": WorkflowErrorCode.INTERNAL_ERROR,
|
55
56
|
"NODE_EXECUTION": WorkflowErrorCode.NODE_EXECUTION,
|
56
57
|
"LLM_PROVIDER": WorkflowErrorCode.PROVIDER_ERROR,
|
58
|
+
"INVALID_CODE": WorkflowErrorCode.INVALID_CODE,
|
57
59
|
"INVALID_TEMPLATE": WorkflowErrorCode.INVALID_TEMPLATE,
|
58
60
|
"USER_DEFINED_ERROR": WorkflowErrorCode.USER_DEFINED_ERROR,
|
59
61
|
}
|
@@ -71,6 +73,7 @@ _WORKFLOW_ERROR_CODE_TO_VELLUM_ERROR_CODE: Dict[WorkflowErrorCode, VellumErrorCo
|
|
71
73
|
WorkflowErrorCode.INVALID_INPUTS: "INVALID_INPUTS",
|
72
74
|
WorkflowErrorCode.INVALID_OUTPUTS: "INVALID_REQUEST",
|
73
75
|
WorkflowErrorCode.INVALID_STATE: "INVALID_REQUEST",
|
76
|
+
WorkflowErrorCode.INVALID_CODE: "INVALID_CODE",
|
74
77
|
WorkflowErrorCode.INVALID_TEMPLATE: "INVALID_INPUTS",
|
75
78
|
WorkflowErrorCode.INTERNAL_ERROR: "INTERNAL_SERVER_ERROR",
|
76
79
|
WorkflowErrorCode.NODE_EXECUTION: "USER_DEFINED_ERROR",
|
@@ -1,6 +1,6 @@
|
|
1
|
-
from typing import Any,
|
1
|
+
from typing import Any, ClassVar, Dict, Generic, Mapping, Tuple, Type, TypeVar, get_args
|
2
2
|
|
3
|
-
from vellum.utils.templating.constants import DEFAULT_JINJA_CUSTOM_FILTERS, DEFAULT_JINJA_GLOBALS
|
3
|
+
from vellum.utils.templating.constants import DEFAULT_JINJA_CUSTOM_FILTERS, DEFAULT_JINJA_GLOBALS, FilterFunc
|
4
4
|
from vellum.utils.templating.exceptions import JinjaTemplateError
|
5
5
|
from vellum.utils.templating.render import render_sandboxed_jinja_template
|
6
6
|
from vellum.workflows.errors import WorkflowErrorCode
|
@@ -54,7 +54,7 @@ class TemplatingNode(BaseNode[StateType], Generic[StateType, _OutputType], metac
|
|
54
54
|
inputs: ClassVar[EntityInputsInterface]
|
55
55
|
|
56
56
|
jinja_globals: Dict[str, Any] = DEFAULT_JINJA_GLOBALS
|
57
|
-
jinja_custom_filters: Mapping[str,
|
57
|
+
jinja_custom_filters: Mapping[str, FilterFunc] = DEFAULT_JINJA_CUSTOM_FILTERS
|
58
58
|
|
59
59
|
class Outputs(BaseNode.Outputs):
|
60
60
|
"""
|
@@ -1,8 +1,11 @@
|
|
1
|
+
import pytest
|
1
2
|
import json
|
2
|
-
from typing import List
|
3
|
+
from typing import List, Union
|
3
4
|
|
4
5
|
from vellum.client.types.chat_message import ChatMessage
|
5
6
|
from vellum.client.types.function_call import FunctionCall
|
7
|
+
from vellum.client.types.function_call_vellum_value import FunctionCallVellumValue
|
8
|
+
from vellum.workflows.exceptions import NodeException
|
6
9
|
from vellum.workflows.nodes.bases.base import BaseNode
|
7
10
|
from vellum.workflows.nodes.core.templating_node.node import TemplatingNode
|
8
11
|
from vellum.workflows.state import BaseState
|
@@ -155,3 +158,65 @@ def test_templating_node__function_call_output():
|
|
155
158
|
|
156
159
|
# THEN the output is the expected function call
|
157
160
|
assert outputs.result == FunctionCall(name="test", arguments={"key": "value"})
|
161
|
+
|
162
|
+
|
163
|
+
def test_templating_node__blank_json_input():
|
164
|
+
"""Test that templating node properly handles blank JSON input."""
|
165
|
+
|
166
|
+
# GIVEN a templating node that tries to parse blank JSON
|
167
|
+
class BlankJsonTemplateNode(TemplatingNode[BaseState, Json]):
|
168
|
+
template = "{{ json.loads(data) }}"
|
169
|
+
inputs = {
|
170
|
+
"data": "", # Blank input
|
171
|
+
}
|
172
|
+
|
173
|
+
# WHEN the node is run
|
174
|
+
node = BlankJsonTemplateNode()
|
175
|
+
|
176
|
+
# THEN it should raise an appropriate error
|
177
|
+
with pytest.raises(NodeException) as exc_info:
|
178
|
+
node.run()
|
179
|
+
|
180
|
+
assert "Unable to render jinja template:\nCannot run json.loads() on empty input" in str(exc_info.value)
|
181
|
+
|
182
|
+
|
183
|
+
def test_templating_node__union_float_int_output():
|
184
|
+
# GIVEN a templating node that outputs either a float or an int
|
185
|
+
class UnionTemplateNode(TemplatingNode[BaseState, Union[float, int]]):
|
186
|
+
template = """{{ obj[\"score\"] | float }}"""
|
187
|
+
inputs = {
|
188
|
+
"obj": {"score": 42.5},
|
189
|
+
}
|
190
|
+
|
191
|
+
# WHEN the node is run
|
192
|
+
node = UnionTemplateNode()
|
193
|
+
outputs = node.run()
|
194
|
+
|
195
|
+
# THEN it should correctly parse as a float
|
196
|
+
assert outputs.result == 42.5
|
197
|
+
|
198
|
+
|
199
|
+
def test_templating_node__replace_filter():
|
200
|
+
# GIVEN a templating node that outputs a complex object
|
201
|
+
class ReplaceFilterTemplateNode(TemplatingNode[BaseState, Json]):
|
202
|
+
template = """{{- prompt_outputs | selectattr(\'type\', \'equalto\', \'FUNCTION_CALL\') \
|
203
|
+
| list | replace(\"\\n\",\",\") -}}"""
|
204
|
+
inputs = {
|
205
|
+
"prompt_outputs": [FunctionCallVellumValue(value=FunctionCall(name="test", arguments={"key": "value"}))]
|
206
|
+
}
|
207
|
+
|
208
|
+
# WHEN the node is run
|
209
|
+
node = ReplaceFilterTemplateNode()
|
210
|
+
outputs = node.run()
|
211
|
+
|
212
|
+
# THEN the output is the expected JSON
|
213
|
+
assert outputs.result == [
|
214
|
+
{
|
215
|
+
"type": "FUNCTION_CALL",
|
216
|
+
"value": {
|
217
|
+
"name": "test",
|
218
|
+
"arguments": {"key": "value"},
|
219
|
+
"id": None,
|
220
|
+
},
|
221
|
+
}
|
222
|
+
]
|
@@ -93,13 +93,14 @@ class CodeExecutionNode(BaseNode[StateType], Generic[StateType, _OutputType], me
|
|
93
93
|
log: str
|
94
94
|
|
95
95
|
def run(self) -> Outputs:
|
96
|
-
input_values = self._compile_code_inputs()
|
97
96
|
output_type = self.__class__.get_output_type()
|
98
97
|
code = self._resolve_code()
|
99
98
|
if not self.packages and self.runtime == "PYTHON_3_11_6":
|
100
|
-
logs, result = run_code_inline(code,
|
99
|
+
logs, result = run_code_inline(code, self.code_inputs, output_type)
|
101
100
|
return self.Outputs(result=result, log=logs)
|
101
|
+
|
102
102
|
else:
|
103
|
+
input_values = self._compile_code_inputs()
|
103
104
|
expected_output_type = primitive_type_to_vellum_variable_type(output_type)
|
104
105
|
|
105
106
|
code_execution_result = self._context.vellum_client.execute_code(
|
@@ -131,7 +132,7 @@ class CodeExecutionNode(BaseNode[StateType], Generic[StateType, _OutputType], me
|
|
131
132
|
compiled_inputs.append(
|
132
133
|
StringInput(
|
133
134
|
name=input_name,
|
134
|
-
value=
|
135
|
+
value=input_value,
|
135
136
|
)
|
136
137
|
)
|
137
138
|
elif isinstance(input_value, VellumSecret):
|
@@ -193,7 +194,7 @@ class CodeExecutionNode(BaseNode[StateType], Generic[StateType, _OutputType], me
|
|
193
194
|
)
|
194
195
|
else:
|
195
196
|
raise NodeException(
|
196
|
-
message=f"Unrecognized input type for input '{input_name}'",
|
197
|
+
message=f"Unrecognized input type for input '{input_name}': {input_value.__class__.__name__}",
|
197
198
|
code=WorkflowErrorCode.INVALID_INPUTS,
|
198
199
|
)
|
199
200
|
|
@@ -7,6 +7,7 @@ from vellum.client.types.code_execution_package import CodeExecutionPackage
|
|
7
7
|
from vellum.client.types.code_executor_secret_input import CodeExecutorSecretInput
|
8
8
|
from vellum.client.types.function_call import FunctionCall
|
9
9
|
from vellum.client.types.number_input import NumberInput
|
10
|
+
from vellum.workflows.errors import WorkflowErrorCode
|
10
11
|
from vellum.workflows.exceptions import NodeException
|
11
12
|
from vellum.workflows.inputs.base import BaseInputs
|
12
13
|
from vellum.workflows.nodes.displayable.code_execution_node import CodeExecutionNode
|
@@ -559,3 +560,53 @@ def main(arg1: List[Dict]) -> float:
|
|
559
560
|
|
560
561
|
# AND we should not have invoked the Code via Vellum since it's running inline
|
561
562
|
vellum_client.execute_code.assert_not_called()
|
563
|
+
|
564
|
+
|
565
|
+
def test_run_node__code_execution_error():
|
566
|
+
# GIVEN a node that will raise an error during execution
|
567
|
+
class State(BaseState):
|
568
|
+
pass
|
569
|
+
|
570
|
+
class ExampleCodeExecutionNode(CodeExecutionNode[State, int]):
|
571
|
+
code = """\
|
572
|
+
def main(arg1: int, arg2: int) -> int:
|
573
|
+
return arg1 + arg2 + arg3
|
574
|
+
"""
|
575
|
+
runtime = "PYTHON_3_11_6"
|
576
|
+
code_inputs = {
|
577
|
+
"arg1": 1,
|
578
|
+
"arg2": 2,
|
579
|
+
}
|
580
|
+
|
581
|
+
# WHEN we run the node
|
582
|
+
node = ExampleCodeExecutionNode(state=State())
|
583
|
+
|
584
|
+
# THEN it should raise a NodeException with the execution error
|
585
|
+
with pytest.raises(NodeException) as exc_info:
|
586
|
+
node.run()
|
587
|
+
|
588
|
+
# AND the error should contain the execution error details
|
589
|
+
assert "name 'arg3' is not defined" in str(exc_info.value)
|
590
|
+
assert exc_info.value.code == WorkflowErrorCode.INVALID_CODE
|
591
|
+
|
592
|
+
|
593
|
+
def test_run_node__array_of_bools_input():
|
594
|
+
# GIVEN a node that will raise an error during execution
|
595
|
+
class ExampleCodeExecutionNode(CodeExecutionNode[BaseState, int]):
|
596
|
+
code = """\
|
597
|
+
def main(arg1: list[bool]) -> int:
|
598
|
+
return len(arg1)
|
599
|
+
"""
|
600
|
+
runtime = "PYTHON_3_11_6"
|
601
|
+
code_inputs = {
|
602
|
+
"arg1": [True, False, True],
|
603
|
+
}
|
604
|
+
|
605
|
+
# WHEN we run the node
|
606
|
+
node = ExampleCodeExecutionNode()
|
607
|
+
|
608
|
+
# THEN it should raise a NodeException with the execution error
|
609
|
+
outputs = node.run()
|
610
|
+
|
611
|
+
# AND the error should contain the execution error details
|
612
|
+
assert outputs == {"result": 3, "log": ""}
|
@@ -1,14 +1,13 @@
|
|
1
1
|
import io
|
2
2
|
import os
|
3
3
|
import re
|
4
|
-
from typing import Any,
|
4
|
+
from typing import Any, Tuple, Union, get_args, get_origin
|
5
5
|
|
6
6
|
from pydantic import BaseModel, ValidationError
|
7
7
|
|
8
|
-
from vellum import VellumValue
|
9
|
-
from vellum.client.types.code_executor_input import CodeExecutorInput
|
10
8
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
11
9
|
from vellum.workflows.exceptions import NodeException
|
10
|
+
from vellum.workflows.types.core import EntityInputsInterface
|
12
11
|
|
13
12
|
|
14
13
|
def read_file_from_path(node_filepath: str, script_filepath: str) -> Union[str, None]:
|
@@ -70,13 +69,11 @@ def _clean_for_dict_wrapper(obj):
|
|
70
69
|
|
71
70
|
def run_code_inline(
|
72
71
|
code: str,
|
73
|
-
|
72
|
+
inputs: EntityInputsInterface,
|
74
73
|
output_type: Any,
|
75
74
|
) -> Tuple[str, Any]:
|
76
75
|
log_buffer = io.StringIO()
|
77
76
|
|
78
|
-
VELLUM_TYPES = get_args(VellumValue)
|
79
|
-
|
80
77
|
def wrap_value(value):
|
81
78
|
if isinstance(value, list):
|
82
79
|
return ListWrapper(
|
@@ -84,7 +81,7 @@ def run_code_inline(
|
|
84
81
|
# Convert VellumValue to dict with its fields
|
85
82
|
(
|
86
83
|
item.model_dump()
|
87
|
-
if isinstance(item,
|
84
|
+
if isinstance(item, BaseModel)
|
88
85
|
else _clean_for_dict_wrapper(item) if isinstance(item, (dict, list)) else item
|
89
86
|
)
|
90
87
|
for item in value
|
@@ -93,18 +90,23 @@ def run_code_inline(
|
|
93
90
|
return _clean_for_dict_wrapper(value)
|
94
91
|
|
95
92
|
exec_globals = {
|
96
|
-
"__arg__inputs": {
|
93
|
+
"__arg__inputs": {name: wrap_value(value) for name, value in inputs.items()},
|
97
94
|
"__arg__out": None,
|
98
95
|
"print": lambda *args, **kwargs: log_buffer.write(f"{' '.join(args)}\n"),
|
99
96
|
}
|
100
|
-
run_args = [f"{
|
97
|
+
run_args = [f"{name}=__arg__inputs['{name}']" for name in inputs.keys()]
|
101
98
|
execution_code = f"""\
|
102
99
|
{code}
|
103
100
|
|
104
101
|
__arg__out = main({", ".join(run_args)})
|
105
102
|
"""
|
106
|
-
|
107
|
-
|
103
|
+
try:
|
104
|
+
exec(execution_code, exec_globals)
|
105
|
+
except Exception as e:
|
106
|
+
raise NodeException(
|
107
|
+
code=WorkflowErrorCode.INVALID_CODE,
|
108
|
+
message=str(e),
|
109
|
+
)
|
108
110
|
|
109
111
|
logs = log_buffer.getvalue()
|
110
112
|
result = exec_globals["__arg__out"]
|
vellum/workflows/nodes/utils.py
CHANGED
@@ -2,7 +2,7 @@ from functools import cache
|
|
2
2
|
import json
|
3
3
|
import sys
|
4
4
|
from types import ModuleType
|
5
|
-
from typing import Any, Callable, Optional, Type, TypeVar, get_args, get_origin
|
5
|
+
from typing import Any, Callable, Optional, Type, TypeVar, Union, get_args, get_origin
|
6
6
|
|
7
7
|
from pydantic import BaseModel
|
8
8
|
|
@@ -113,6 +113,14 @@ def parse_type_from_str(result_as_str: str, output_type: Any) -> Any:
|
|
113
113
|
except json.JSONDecodeError:
|
114
114
|
raise ValueError("Invalid JSON format for result_as_str")
|
115
115
|
|
116
|
+
if get_origin(output_type) is Union:
|
117
|
+
for inner_type in get_args(output_type):
|
118
|
+
try:
|
119
|
+
return parse_type_from_str(result_as_str, inner_type)
|
120
|
+
except ValueError:
|
121
|
+
continue
|
122
|
+
raise ValueError(f"Could not parse with any of the Union types: {output_type}")
|
123
|
+
|
116
124
|
if issubclass(output_type, BaseModel):
|
117
125
|
try:
|
118
126
|
data = json.loads(result_as_str)
|
vellum/workflows/state/base.py
CHANGED
@@ -18,7 +18,13 @@ from vellum.workflows.inputs.base import BaseInputs
|
|
18
18
|
from vellum.workflows.references import ExternalInputReference, OutputReference, StateValueReference
|
19
19
|
from vellum.workflows.types.generics import StateType
|
20
20
|
from vellum.workflows.types.stack import Stack
|
21
|
-
from vellum.workflows.types.utils import
|
21
|
+
from vellum.workflows.types.utils import (
|
22
|
+
datetime_now,
|
23
|
+
deepcopy_with_exclusions,
|
24
|
+
get_class_attr_names,
|
25
|
+
get_class_by_qualname,
|
26
|
+
infer_types,
|
27
|
+
)
|
22
28
|
|
23
29
|
if TYPE_CHECKING:
|
24
30
|
from vellum.workflows.nodes.bases import BaseNode
|
@@ -47,6 +53,20 @@ class _BaseStateMeta(type):
|
|
47
53
|
|
48
54
|
return super().__getattribute__(name)
|
49
55
|
|
56
|
+
def __iter__(cls) -> Iterator[StateValueReference]:
|
57
|
+
# We iterate through the inheritance hierarchy to find all the StateValueReference attached to this
|
58
|
+
# Inputs class. __mro__ is the method resolution order, which is the order in which base classes are resolved.
|
59
|
+
for resolved_cls in cls.__mro__:
|
60
|
+
attr_names = get_class_attr_names(resolved_cls)
|
61
|
+
for attr_name in attr_names:
|
62
|
+
if attr_name == "meta":
|
63
|
+
continue
|
64
|
+
attr_value = getattr(resolved_cls, attr_name)
|
65
|
+
if not isinstance(attr_value, (StateValueReference)):
|
66
|
+
continue
|
67
|
+
|
68
|
+
yield attr_value
|
69
|
+
|
50
70
|
|
51
71
|
class _SnapshottableDict(dict, _Snapshottable):
|
52
72
|
def __setitem__(self, key: Any, value: Any) -> None:
|