vellum-ai 0.13.28__py3-none-any.whl → 0.14.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/prompts/blocks/compilation.py +23 -16
- vellum/prompts/blocks/tests/test_compilation.py +29 -0
- vellum/utils/templating/render.py +2 -0
- vellum/workflows/constants.py +8 -3
- vellum/workflows/descriptors/tests/test_utils.py +21 -0
- vellum/workflows/descriptors/utils.py +3 -3
- vellum/workflows/errors/types.py +4 -1
- vellum/workflows/expressions/coalesce_expression.py +2 -2
- vellum/workflows/expressions/contains.py +4 -3
- vellum/workflows/expressions/does_not_contain.py +2 -1
- vellum/workflows/expressions/is_nil.py +2 -2
- vellum/workflows/expressions/is_not_nil.py +2 -2
- vellum/workflows/expressions/is_not_undefined.py +2 -2
- vellum/workflows/expressions/is_undefined.py +2 -2
- vellum/workflows/nodes/bases/base.py +19 -3
- vellum/workflows/nodes/bases/tests/test_base_node.py +84 -0
- vellum/workflows/nodes/core/inline_subworkflow_node/node.py +3 -3
- vellum/workflows/nodes/core/map_node/node.py +5 -0
- vellum/workflows/nodes/core/map_node/tests/test_node.py +22 -0
- vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +39 -1
- vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +68 -2
- vellum/workflows/nodes/displayable/code_execution_node/utils.py +30 -7
- vellum/workflows/nodes/utils.py +9 -1
- vellum/workflows/outputs/base.py +21 -19
- vellum/workflows/references/external_input.py +2 -2
- vellum/workflows/references/lazy.py +2 -2
- vellum/workflows/references/output.py +7 -7
- vellum/workflows/runner/runner.py +20 -15
- vellum/workflows/state/base.py +23 -3
- vellum/workflows/state/tests/test_state.py +7 -11
- vellum/workflows/workflows/base.py +20 -0
- vellum/workflows/workflows/tests/__init__.py +0 -0
- vellum/workflows/workflows/tests/test_base_workflow.py +80 -0
- {vellum_ai-0.13.28.dist-info → vellum_ai-0.14.1.dist-info}/METADATA +1 -1
- {vellum_ai-0.13.28.dist-info → vellum_ai-0.14.1.dist-info}/RECORD +67 -62
- vellum_ee/workflows/display/base.py +14 -0
- vellum_ee/workflows/display/nodes/base_node_display.py +13 -24
- vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_node.py +52 -0
- vellum_ee/workflows/display/tests/test_vellum_workflow_display.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +4 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +3 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +4 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +243 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +1 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +1 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +1 -0
- vellum_ee/workflows/display/types.py +5 -1
- vellum_ee/workflows/display/utils/expressions.py +26 -0
- vellum_ee/workflows/display/utils/vellum.py +5 -0
- vellum_ee/workflows/display/vellum.py +14 -0
- vellum_ee/workflows/display/workflows/base_workflow_display.py +30 -1
- vellum_ee/workflows/display/workflows/vellum_workflow_display.py +41 -0
- {vellum_ai-0.13.28.dist-info → vellum_ai-0.14.1.dist-info}/LICENSE +0 -0
- {vellum_ai-0.13.28.dist-info → vellum_ai-0.14.1.dist-info}/WHEEL +0 -0
- {vellum_ai-0.13.28.dist-info → vellum_ai-0.14.1.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
|
|
18
18
|
headers: typing.Dict[str, str] = {
|
19
19
|
"X-Fern-Language": "Python",
|
20
20
|
"X-Fern-SDK-Name": "vellum-ai",
|
21
|
-
"X-Fern-SDK-Version": "0.
|
21
|
+
"X-Fern-SDK-Version": "0.14.1",
|
22
22
|
}
|
23
23
|
headers["X_API_KEY"] = self.api_key
|
24
24
|
return headers
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import json
|
2
|
-
from typing import
|
2
|
+
from typing import Sequence, Union, cast
|
3
3
|
|
4
4
|
from vellum import (
|
5
5
|
ChatMessage,
|
@@ -14,6 +14,7 @@ from vellum.client.types.audio_vellum_value import AudioVellumValue
|
|
14
14
|
from vellum.client.types.function_call import FunctionCall
|
15
15
|
from vellum.client.types.function_call_vellum_value import FunctionCallVellumValue
|
16
16
|
from vellum.client.types.image_vellum_value import ImageVellumValue
|
17
|
+
from vellum.client.types.number_input import NumberInput
|
17
18
|
from vellum.client.types.vellum_audio import VellumAudio
|
18
19
|
from vellum.client.types.vellum_image import VellumImage
|
19
20
|
from vellum.prompts.blocks.exceptions import PromptCompilationError
|
@@ -22,15 +23,20 @@ from vellum.utils.templating.constants import DEFAULT_JINJA_CUSTOM_FILTERS
|
|
22
23
|
from vellum.utils.templating.render import render_sandboxed_jinja_template
|
23
24
|
from vellum.utils.typing import cast_not_optional
|
24
25
|
|
26
|
+
PromptInput = Union[PromptRequestInput, NumberInput]
|
27
|
+
|
25
28
|
|
26
29
|
def compile_prompt_blocks(
|
27
30
|
blocks: list[PromptBlock],
|
28
|
-
inputs:
|
31
|
+
inputs: Sequence[PromptInput],
|
29
32
|
input_variables: list[VellumVariable],
|
30
33
|
) -> list[CompiledPromptBlock]:
|
31
34
|
"""Compiles a list of Prompt Blocks, performing all variable substitutions and Jinja templating needed."""
|
32
35
|
|
33
36
|
sanitized_inputs = _sanitize_inputs(inputs)
|
37
|
+
inputs_by_name = {
|
38
|
+
(input_.name if isinstance(input_, NumberInput) else input_.key): input_ for input_ in sanitized_inputs
|
39
|
+
}
|
34
40
|
|
35
41
|
compiled_blocks: list[CompiledPromptBlock] = []
|
36
42
|
for block in blocks:
|
@@ -66,7 +72,7 @@ def compile_prompt_blocks(
|
|
66
72
|
|
67
73
|
rendered_template = render_sandboxed_jinja_template(
|
68
74
|
template=block.template,
|
69
|
-
input_values={
|
75
|
+
input_values={name: inp.value for name, inp in inputs_by_name.items()},
|
70
76
|
jinja_custom_filters=DEFAULT_JINJA_CUSTOM_FILTERS,
|
71
77
|
jinja_globals=DEFAULT_JINJA_CUSTOM_FILTERS,
|
72
78
|
)
|
@@ -80,9 +86,7 @@ def compile_prompt_blocks(
|
|
80
86
|
)
|
81
87
|
|
82
88
|
elif block.block_type == "VARIABLE":
|
83
|
-
compiled_input
|
84
|
-
(input_ for input_ in sanitized_inputs if input_.key == str(block.input_variable)), None
|
85
|
-
)
|
89
|
+
compiled_input = inputs_by_name.get(block.input_variable)
|
86
90
|
if compiled_input is None:
|
87
91
|
raise PromptCompilationError(f"Input variable '{block.input_variable}' not found")
|
88
92
|
|
@@ -196,23 +200,26 @@ def _compile_chat_messages_as_prompt_blocks(chat_messages: list[ChatMessage]) ->
|
|
196
200
|
|
197
201
|
def _compile_rich_text_block_as_value_block(
|
198
202
|
block: RichTextPromptBlock,
|
199
|
-
inputs: list[
|
203
|
+
inputs: list[PromptInput],
|
200
204
|
) -> CompiledValuePromptBlock:
|
201
205
|
value: str = ""
|
206
|
+
inputs_by_name = {(input_.name if isinstance(input_, NumberInput) else input_.key): input_ for input_ in inputs}
|
202
207
|
for child_block in block.blocks:
|
203
208
|
if child_block.block_type == "PLAIN_TEXT":
|
204
209
|
value += child_block.text
|
205
210
|
elif child_block.block_type == "VARIABLE":
|
206
|
-
|
207
|
-
if
|
211
|
+
input = inputs_by_name.get(child_block.input_variable)
|
212
|
+
if input is None:
|
208
213
|
raise PromptCompilationError(f"Input variable '{child_block.input_variable}' not found")
|
209
|
-
elif
|
210
|
-
value += str(
|
211
|
-
elif
|
212
|
-
value += json.dumps(
|
214
|
+
elif input.type == "STRING":
|
215
|
+
value += str(input.value)
|
216
|
+
elif input.type == "JSON":
|
217
|
+
value += json.dumps(input.value, indent=4)
|
218
|
+
elif input.type == "NUMBER":
|
219
|
+
value += str(input.value)
|
213
220
|
else:
|
214
221
|
raise PromptCompilationError(
|
215
|
-
f"Input variable '{child_block.input_variable}'
|
222
|
+
f"Input variable '{child_block.input_variable}' has an invalid type: {input.type}"
|
216
223
|
)
|
217
224
|
else:
|
218
225
|
raise PromptCompilationError(f"Invalid child block_type for RICH_TEXT: {child_block.block_type}")
|
@@ -220,8 +227,8 @@ def _compile_rich_text_block_as_value_block(
|
|
220
227
|
return CompiledValuePromptBlock(content=StringVellumValue(value=value), cache_config=block.cache_config)
|
221
228
|
|
222
229
|
|
223
|
-
def _sanitize_inputs(inputs:
|
224
|
-
sanitized_inputs: list[
|
230
|
+
def _sanitize_inputs(inputs: Sequence[PromptInput]) -> list[PromptInput]:
|
231
|
+
sanitized_inputs: list[PromptInput] = []
|
225
232
|
for input_ in inputs:
|
226
233
|
if input_.type == "CHAT_HISTORY" and input_.value is None:
|
227
234
|
sanitized_inputs.append(input_.model_copy(update={"value": cast(list[ChatMessage], [])}))
|
@@ -10,6 +10,7 @@ from vellum import (
|
|
10
10
|
VariablePromptBlock,
|
11
11
|
VellumVariable,
|
12
12
|
)
|
13
|
+
from vellum.client.types.number_input import NumberInput
|
13
14
|
from vellum.prompts.blocks.compilation import compile_prompt_blocks
|
14
15
|
from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock, CompiledValuePromptBlock
|
15
16
|
|
@@ -94,6 +95,33 @@ from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock, Compiled
|
|
94
95
|
),
|
95
96
|
],
|
96
97
|
),
|
98
|
+
(
|
99
|
+
[
|
100
|
+
ChatMessagePromptBlock(
|
101
|
+
chat_role="USER",
|
102
|
+
blocks=[
|
103
|
+
RichTextPromptBlock(
|
104
|
+
blocks=[
|
105
|
+
PlainTextPromptBlock(text="Count to "),
|
106
|
+
VariablePromptBlock(input_variable="count"),
|
107
|
+
]
|
108
|
+
)
|
109
|
+
],
|
110
|
+
)
|
111
|
+
],
|
112
|
+
[
|
113
|
+
# TODO: We don't yet have PromptRequestNumberInput. We should either add it or migrate
|
114
|
+
# Prompts to using these more generic inputs.
|
115
|
+
NumberInput(name="count", value=10),
|
116
|
+
],
|
117
|
+
[VellumVariable(id="901ec2d6-430c-4341-b963-ca689006f5cc", type="NUMBER", key="count")],
|
118
|
+
[
|
119
|
+
CompiledChatMessagePromptBlock(
|
120
|
+
role="USER",
|
121
|
+
blocks=[CompiledValuePromptBlock(content=StringVellumValue(value="Count to 10.0"))],
|
122
|
+
),
|
123
|
+
],
|
124
|
+
),
|
97
125
|
],
|
98
126
|
ids=[
|
99
127
|
"empty",
|
@@ -102,6 +130,7 @@ from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock, Compiled
|
|
102
130
|
"rich-text-no-variables",
|
103
131
|
"rich-text-with-variables",
|
104
132
|
"chat-message",
|
133
|
+
"number-input",
|
105
134
|
],
|
106
135
|
)
|
107
136
|
def test_compile_prompt_blocks__happy(blocks, inputs, input_variables, expected):
|
@@ -42,6 +42,8 @@ def render_sandboxed_jinja_template(
|
|
42
42
|
|
43
43
|
rendered_template = jinja_template.render(input_values)
|
44
44
|
except json.JSONDecodeError as e:
|
45
|
+
if not e.doc:
|
46
|
+
raise JinjaTemplateError("Unable to render jinja template:\n" "Cannot run json.loads() on empty input")
|
45
47
|
if e.msg == "Invalid control character at":
|
46
48
|
raise JinjaTemplateError(
|
47
49
|
"Unable to render jinja template:\n"
|
vellum/workflows/constants.py
CHANGED
@@ -4,11 +4,11 @@ from typing import Any, cast
|
|
4
4
|
|
5
5
|
class _UndefMeta(type):
|
6
6
|
def __repr__(cls) -> str:
|
7
|
-
return "
|
7
|
+
return "undefined"
|
8
8
|
|
9
9
|
def __getattribute__(cls, name: str) -> Any:
|
10
10
|
if name == "__class__":
|
11
|
-
# ensures that
|
11
|
+
# ensures that undefined.__class__ == undefined
|
12
12
|
return cls
|
13
13
|
|
14
14
|
return super().__getattribute__(name)
|
@@ -17,7 +17,12 @@ class _UndefMeta(type):
|
|
17
17
|
return False
|
18
18
|
|
19
19
|
|
20
|
-
class
|
20
|
+
class undefined(metaclass=_UndefMeta):
|
21
|
+
"""
|
22
|
+
A singleton class that represents an `undefined` value, mirroring the behavior of the `undefined`
|
23
|
+
value in TypeScript.
|
24
|
+
"""
|
25
|
+
|
21
26
|
pass
|
22
27
|
|
23
28
|
|
@@ -1,6 +1,7 @@
|
|
1
1
|
import pytest
|
2
2
|
|
3
3
|
from vellum.workflows.descriptors.utils import resolve_value
|
4
|
+
from vellum.workflows.errors.types import WorkflowError, WorkflowErrorCode
|
4
5
|
from vellum.workflows.nodes.bases.base import BaseNode
|
5
6
|
from vellum.workflows.references.constant import ConstantValueReference
|
6
7
|
from vellum.workflows.state.base import BaseState
|
@@ -77,6 +78,24 @@ class DummyNode(BaseNode[FixtureState]):
|
|
77
78
|
(FixtureState.zeta["foo"], "bar"),
|
78
79
|
(ConstantValueReference(1), 1),
|
79
80
|
(FixtureState.theta[0], "baz"),
|
81
|
+
(
|
82
|
+
ConstantValueReference(
|
83
|
+
WorkflowError(
|
84
|
+
message="This is a test",
|
85
|
+
code=WorkflowErrorCode.USER_DEFINED_ERROR,
|
86
|
+
)
|
87
|
+
).contains("test"),
|
88
|
+
True,
|
89
|
+
),
|
90
|
+
(
|
91
|
+
ConstantValueReference(
|
92
|
+
WorkflowError(
|
93
|
+
message="This is a test",
|
94
|
+
code=WorkflowErrorCode.USER_DEFINED_ERROR,
|
95
|
+
)
|
96
|
+
).does_not_contain("test"),
|
97
|
+
False,
|
98
|
+
),
|
80
99
|
],
|
81
100
|
ids=[
|
82
101
|
"or",
|
@@ -122,6 +141,8 @@ class DummyNode(BaseNode[FixtureState]):
|
|
122
141
|
"accessor",
|
123
142
|
"constants",
|
124
143
|
"list_index",
|
144
|
+
"error_contains",
|
145
|
+
"error_does_not_contain",
|
125
146
|
],
|
126
147
|
)
|
127
148
|
def test_resolve_value__happy_path(descriptor, expected_value):
|
@@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Sequence, Set, TypeVar, Union, cast, ove
|
|
5
5
|
|
6
6
|
from pydantic import BaseModel
|
7
7
|
|
8
|
-
from vellum.workflows.constants import
|
8
|
+
from vellum.workflows.constants import undefined
|
9
9
|
from vellum.workflows.descriptors.base import BaseDescriptor
|
10
10
|
from vellum.workflows.state.base import BaseState
|
11
11
|
|
@@ -93,10 +93,10 @@ def resolve_value(
|
|
93
93
|
|
94
94
|
def is_unresolved(value: Any) -> bool:
|
95
95
|
"""
|
96
|
-
Recursively checks if a value has an unresolved value, represented by
|
96
|
+
Recursively checks if a value has an unresolved value, represented by undefined.
|
97
97
|
"""
|
98
98
|
|
99
|
-
if value is
|
99
|
+
if value is undefined:
|
100
100
|
return True
|
101
101
|
|
102
102
|
if dataclasses.is_dataclass(value):
|
vellum/workflows/errors/types.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
from dataclasses import dataclass
|
2
2
|
from enum import Enum
|
3
|
-
from typing import Dict
|
3
|
+
from typing import Any, Dict
|
4
4
|
|
5
5
|
from vellum.client.types.vellum_error import VellumError
|
6
6
|
from vellum.client.types.vellum_error_code_enum import VellumErrorCodeEnum
|
@@ -26,6 +26,9 @@ class WorkflowError:
|
|
26
26
|
message: str
|
27
27
|
code: WorkflowErrorCode
|
28
28
|
|
29
|
+
def __contains__(self, item: Any) -> bool:
|
30
|
+
return item in self.message
|
31
|
+
|
29
32
|
|
30
33
|
_VELLUM_ERROR_CODE_TO_WORKFLOW_ERROR_CODE: Dict[VellumErrorCodeEnum, WorkflowErrorCode] = {
|
31
34
|
"INVALID_REQUEST": WorkflowErrorCode.INVALID_INPUTS,
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from typing import TypeVar, Union
|
2
2
|
|
3
|
-
from vellum.workflows.constants import
|
3
|
+
from vellum.workflows.constants import undefined
|
4
4
|
from vellum.workflows.descriptors.base import BaseDescriptor
|
5
5
|
from vellum.workflows.descriptors.utils import resolve_value
|
6
6
|
from vellum.workflows.state.base import BaseState
|
@@ -27,7 +27,7 @@ class CoalesceExpression(BaseDescriptor[Union[LHS, RHS]]):
|
|
27
27
|
|
28
28
|
def resolve(self, state: "BaseState") -> Union[LHS, RHS]:
|
29
29
|
lhs = resolve_value(self._lhs, state)
|
30
|
-
if lhs is not
|
30
|
+
if lhs is not undefined and lhs is not None:
|
31
31
|
return lhs
|
32
32
|
|
33
33
|
return resolve_value(self._rhs, state)
|
@@ -1,9 +1,10 @@
|
|
1
1
|
from typing import Generic, TypeVar, Union
|
2
2
|
|
3
|
-
from vellum.workflows.constants import
|
3
|
+
from vellum.workflows.constants import undefined
|
4
4
|
from vellum.workflows.descriptors.base import BaseDescriptor
|
5
5
|
from vellum.workflows.descriptors.exceptions import InvalidExpressionException
|
6
6
|
from vellum.workflows.descriptors.utils import resolve_value
|
7
|
+
from vellum.workflows.errors.types import WorkflowError
|
7
8
|
from vellum.workflows.state.base import BaseState
|
8
9
|
|
9
10
|
LHS = TypeVar("LHS")
|
@@ -26,9 +27,9 @@ class ContainsExpression(BaseDescriptor[bool], Generic[LHS, RHS]):
|
|
26
27
|
# https://app.shortcut.com/vellum/story/4658
|
27
28
|
lhs = resolve_value(self._lhs, state)
|
28
29
|
# assumes that lack of is also false
|
29
|
-
if lhs is
|
30
|
+
if lhs is undefined:
|
30
31
|
return False
|
31
|
-
if not isinstance(lhs, (list, tuple, set, dict, str)):
|
32
|
+
if not isinstance(lhs, (list, tuple, set, dict, str, WorkflowError)):
|
32
33
|
raise InvalidExpressionException(
|
33
34
|
f"Expected a LHS that supported `contains`, got `{lhs.__class__.__name__}`"
|
34
35
|
)
|
@@ -3,6 +3,7 @@ from typing import Generic, TypeVar, Union
|
|
3
3
|
from vellum.workflows.descriptors.base import BaseDescriptor
|
4
4
|
from vellum.workflows.descriptors.exceptions import InvalidExpressionException
|
5
5
|
from vellum.workflows.descriptors.utils import resolve_value
|
6
|
+
from vellum.workflows.errors.types import WorkflowError
|
6
7
|
from vellum.workflows.state.base import BaseState
|
7
8
|
|
8
9
|
LHS = TypeVar("LHS")
|
@@ -24,7 +25,7 @@ class DoesNotContainExpression(BaseDescriptor[bool], Generic[LHS, RHS]):
|
|
24
25
|
# Support any type that implements the not in operator
|
25
26
|
# https://app.shortcut.com/vellum/story/4658
|
26
27
|
lhs = resolve_value(self._lhs, state)
|
27
|
-
if not isinstance(lhs, (list, tuple, set, dict, str)):
|
28
|
+
if not isinstance(lhs, (list, tuple, set, dict, str, WorkflowError)):
|
28
29
|
raise InvalidExpressionException(
|
29
30
|
f"Expected a LHS that supported `contains`, got `{lhs.__class__.__name__}`"
|
30
31
|
)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from typing import Generic, TypeVar, Union
|
2
2
|
|
3
|
-
from vellum.workflows.constants import
|
3
|
+
from vellum.workflows.constants import undefined
|
4
4
|
from vellum.workflows.descriptors.base import BaseDescriptor
|
5
5
|
from vellum.workflows.descriptors.utils import resolve_value
|
6
6
|
from vellum.workflows.state.base import BaseState
|
@@ -19,4 +19,4 @@ class IsNilExpression(BaseDescriptor[bool], Generic[_T]):
|
|
19
19
|
|
20
20
|
def resolve(self, state: "BaseState") -> bool:
|
21
21
|
expression = resolve_value(self._expression, state)
|
22
|
-
return expression is None or expression is
|
22
|
+
return expression is None or expression is undefined
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from typing import Generic, TypeVar, Union
|
2
2
|
|
3
|
-
from vellum.workflows.constants import
|
3
|
+
from vellum.workflows.constants import undefined
|
4
4
|
from vellum.workflows.descriptors.base import BaseDescriptor
|
5
5
|
from vellum.workflows.descriptors.utils import resolve_value
|
6
6
|
from vellum.workflows.state.base import BaseState
|
@@ -19,4 +19,4 @@ class IsNotNilExpression(BaseDescriptor[bool], Generic[_T]):
|
|
19
19
|
|
20
20
|
def resolve(self, state: "BaseState") -> bool:
|
21
21
|
expression = resolve_value(self._expression, state)
|
22
|
-
return expression is not None and expression is not
|
22
|
+
return expression is not None and expression is not undefined
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from typing import Generic, TypeVar, Union
|
2
2
|
|
3
|
-
from vellum.workflows.constants import
|
3
|
+
from vellum.workflows.constants import undefined
|
4
4
|
from vellum.workflows.descriptors.base import BaseDescriptor
|
5
5
|
from vellum.workflows.descriptors.utils import resolve_value
|
6
6
|
from vellum.workflows.state.base import BaseState
|
@@ -19,4 +19,4 @@ class IsNotUndefinedExpression(BaseDescriptor[bool], Generic[_T]):
|
|
19
19
|
|
20
20
|
def resolve(self, state: "BaseState") -> bool:
|
21
21
|
expression = resolve_value(self._expression, state)
|
22
|
-
return expression is not
|
22
|
+
return expression is not undefined
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from typing import Generic, TypeVar, Union
|
2
2
|
|
3
|
-
from vellum.workflows.constants import
|
3
|
+
from vellum.workflows.constants import undefined
|
4
4
|
from vellum.workflows.descriptors.base import BaseDescriptor
|
5
5
|
from vellum.workflows.descriptors.utils import resolve_value
|
6
6
|
from vellum.workflows.state.base import BaseState
|
@@ -19,4 +19,4 @@ class IsUndefinedExpression(BaseDescriptor[bool], Generic[_T]):
|
|
19
19
|
|
20
20
|
def resolve(self, state: "BaseState") -> bool:
|
21
21
|
expression = resolve_value(self._expression, state)
|
22
|
-
return expression is
|
22
|
+
return expression is undefined
|
@@ -5,7 +5,7 @@ from types import MappingProxyType
|
|
5
5
|
from uuid import UUID
|
6
6
|
from typing import Any, Dict, Generic, Iterator, Optional, Set, Tuple, Type, TypeVar, Union, cast, get_args
|
7
7
|
|
8
|
-
from vellum.workflows.constants import
|
8
|
+
from vellum.workflows.constants import undefined
|
9
9
|
from vellum.workflows.descriptors.base import BaseDescriptor
|
10
10
|
from vellum.workflows.descriptors.utils import is_unresolved, resolve_value
|
11
11
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
@@ -43,7 +43,23 @@ class BaseNodeMeta(type):
|
|
43
43
|
# TODO: Inherit the inner Output classes from every base class.
|
44
44
|
# https://app.shortcut.com/vellum/story/4007/support-auto-inheriting-parent-node-outputs
|
45
45
|
|
46
|
-
if "Outputs"
|
46
|
+
if "Outputs" in dct:
|
47
|
+
outputs_class = dct["Outputs"]
|
48
|
+
if not any(issubclass(base, BaseOutputs) for base in outputs_class.__bases__):
|
49
|
+
parent_outputs_class = next(
|
50
|
+
(base.Outputs for base in bases if hasattr(base, "Outputs")),
|
51
|
+
BaseOutputs, # Default to BaseOutputs only if no parent has Outputs
|
52
|
+
)
|
53
|
+
|
54
|
+
# Filter out object from bases while preserving other inheritance
|
55
|
+
filtered_bases = tuple(base for base in outputs_class.__bases__ if base is not object)
|
56
|
+
|
57
|
+
dct["Outputs"] = type(
|
58
|
+
f"{name}.Outputs",
|
59
|
+
(parent_outputs_class,) + filtered_bases,
|
60
|
+
{**outputs_class.__dict__, "__module__": dct["__module__"]},
|
61
|
+
)
|
62
|
+
else:
|
47
63
|
for base in reversed(bases):
|
48
64
|
if hasattr(base, "Outputs"):
|
49
65
|
dct["Outputs"] = type(
|
@@ -165,7 +181,7 @@ class BaseNodeMeta(type):
|
|
165
181
|
if attr_name in yielded_attr_names:
|
166
182
|
continue
|
167
183
|
|
168
|
-
attr_value = getattr(resolved_cls, attr_name,
|
184
|
+
attr_value = getattr(resolved_cls, attr_name, undefined)
|
169
185
|
if not isinstance(attr_value, NodeReference):
|
170
186
|
continue
|
171
187
|
|
@@ -5,6 +5,7 @@ from vellum.client.types.string_vellum_value_request import StringVellumValueReq
|
|
5
5
|
from vellum.core.pydantic_utilities import UniversalBaseModel
|
6
6
|
from vellum.workflows.inputs.base import BaseInputs
|
7
7
|
from vellum.workflows.nodes.bases.base import BaseNode
|
8
|
+
from vellum.workflows.outputs.base import BaseOutputs
|
8
9
|
from vellum.workflows.state.base import BaseState, StateMeta
|
9
10
|
|
10
11
|
|
@@ -148,3 +149,86 @@ def test_base_node__node_resolution__descriptor_in_fern_pydantic():
|
|
148
149
|
node = SomeNode(state=State(foo="bar"))
|
149
150
|
|
150
151
|
assert node.model.value == "bar"
|
152
|
+
|
153
|
+
|
154
|
+
def test_base_node__inherit_base_outputs():
|
155
|
+
class MyNode(BaseNode):
|
156
|
+
class Outputs:
|
157
|
+
foo: str
|
158
|
+
|
159
|
+
def run(self):
|
160
|
+
return self.Outputs(foo="bar") # type: ignore
|
161
|
+
|
162
|
+
# TEST that the Outputs class is a subclass of BaseOutputs
|
163
|
+
assert issubclass(MyNode.Outputs, BaseOutputs)
|
164
|
+
|
165
|
+
# TEST that the Outputs class does not inherit from object
|
166
|
+
assert object not in MyNode.Outputs.__bases__
|
167
|
+
|
168
|
+
# TEST that the Outputs class has the correct attributes
|
169
|
+
assert hasattr(MyNode.Outputs, "foo")
|
170
|
+
|
171
|
+
# WHEN the node is run
|
172
|
+
node = MyNode()
|
173
|
+
outputs = node.run()
|
174
|
+
|
175
|
+
# THEN the outputs should be correct
|
176
|
+
assert outputs.foo == "bar"
|
177
|
+
|
178
|
+
|
179
|
+
def test_child_node__inherits_base_outputs_when_no_parent_outputs():
|
180
|
+
class ParentNode(BaseNode): # No Outputs class here
|
181
|
+
pass
|
182
|
+
|
183
|
+
class ChildNode(ParentNode):
|
184
|
+
class Outputs:
|
185
|
+
foo: str
|
186
|
+
|
187
|
+
def run(self):
|
188
|
+
return self.Outputs(foo="bar") # type: ignore
|
189
|
+
|
190
|
+
# TEST that ChildNode.Outputs is a subclass of BaseOutputs (since ParentNode has no Outputs)
|
191
|
+
assert issubclass(ChildNode.Outputs, BaseOutputs)
|
192
|
+
|
193
|
+
# TEST that ChildNode.Outputs has the correct attributes
|
194
|
+
assert hasattr(ChildNode.Outputs, "foo")
|
195
|
+
|
196
|
+
# WHEN the node is run
|
197
|
+
node = ChildNode()
|
198
|
+
outputs = node.run()
|
199
|
+
|
200
|
+
# THEN the outputs should be correct
|
201
|
+
assert outputs.foo == "bar"
|
202
|
+
|
203
|
+
|
204
|
+
def test_outputs_preserves_non_object_bases():
|
205
|
+
class ParentNode(BaseNode):
|
206
|
+
class Outputs:
|
207
|
+
foo: str
|
208
|
+
|
209
|
+
class Foo:
|
210
|
+
bar: str
|
211
|
+
|
212
|
+
class ChildNode(ParentNode):
|
213
|
+
class Outputs(ParentNode.Outputs, Foo):
|
214
|
+
pass
|
215
|
+
|
216
|
+
def run(self):
|
217
|
+
return self.Outputs(foo="bar", bar="baz") # type: ignore
|
218
|
+
|
219
|
+
# TEST that Outputs is a subclass of Foo and ParentNode.Outputs
|
220
|
+
assert Foo in ChildNode.Outputs.__bases__, "Foo should be preserved in bases"
|
221
|
+
assert ParentNode.Outputs in ChildNode.Outputs.__bases__, "ParentNode.Outputs should be preserved in bases"
|
222
|
+
assert object not in ChildNode.Outputs.__bases__, "object should not be in bases"
|
223
|
+
|
224
|
+
# TEST that Outputs has the correct attributes
|
225
|
+
assert hasattr(ChildNode.Outputs, "foo")
|
226
|
+
assert hasattr(ChildNode.Outputs, "bar")
|
227
|
+
|
228
|
+
# WHEN Outputs is instantiated
|
229
|
+
node = ChildNode()
|
230
|
+
outputs = node.run()
|
231
|
+
|
232
|
+
# THEN the output values should be correct
|
233
|
+
assert outputs.foo == "bar"
|
234
|
+
assert outputs.bar == "baz"
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Iterator, Optional, Set, Tuple, Type, TypeVar, Union
|
2
2
|
|
3
|
-
from vellum.workflows.constants import
|
3
|
+
from vellum.workflows.constants import undefined
|
4
4
|
from vellum.workflows.context import execution_context, get_parent_context
|
5
5
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
6
6
|
from vellum.workflows.exceptions import NodeException
|
@@ -67,7 +67,7 @@ class InlineSubworkflowNode(
|
|
67
67
|
"""
|
68
68
|
|
69
69
|
subworkflow: Type["BaseWorkflow[InputsType, InnerStateType]"]
|
70
|
-
subworkflow_inputs: ClassVar[Union[EntityInputsInterface, BaseInputs, Type[
|
70
|
+
subworkflow_inputs: ClassVar[Union[EntityInputsInterface, BaseInputs, Type[undefined]]] = undefined
|
71
71
|
|
72
72
|
def run(self) -> Iterator[BaseOutput]:
|
73
73
|
with execution_context(parent_context=get_parent_context() or self._context.parent_context):
|
@@ -112,7 +112,7 @@ class InlineSubworkflowNode(
|
|
112
112
|
|
113
113
|
def _compile_subworkflow_inputs(self) -> InputsType:
|
114
114
|
inputs_class = self.subworkflow.get_inputs_class()
|
115
|
-
if self.subworkflow_inputs is
|
115
|
+
if self.subworkflow_inputs is undefined:
|
116
116
|
inputs_dict = {}
|
117
117
|
for descriptor in inputs_class:
|
118
118
|
if hasattr(self, descriptor.name):
|
@@ -66,6 +66,11 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
|
|
66
66
|
for output_descripter in self.subworkflow.Outputs:
|
67
67
|
mapped_items[output_descripter.name] = [None] * len(self.items)
|
68
68
|
|
69
|
+
if not self.items:
|
70
|
+
for output_name, output_list in mapped_items.items():
|
71
|
+
yield BaseOutput(name=output_name, value=output_list)
|
72
|
+
return
|
73
|
+
|
69
74
|
self._event_queue: Queue[Tuple[int, WorkflowEvent]] = Queue()
|
70
75
|
self._concurrency_queue: Queue[Thread] = Queue()
|
71
76
|
fulfilled_iterations: List[bool] = []
|
@@ -63,3 +63,25 @@ def test_map_node__use_parallelism():
|
|
63
63
|
# THEN the node should have ran in parallel
|
64
64
|
run_time = (end_ts - start_ts) / 10**9
|
65
65
|
assert run_time < 0.2
|
66
|
+
|
67
|
+
|
68
|
+
def test_map_node__empty_list():
|
69
|
+
# GIVEN a map node that is configured to use the parent's inputs and state
|
70
|
+
@MapNode.wrap(items=[])
|
71
|
+
class TestNode(BaseNode):
|
72
|
+
item = MapNode.SubworkflowInputs.item
|
73
|
+
|
74
|
+
class Outputs(BaseOutputs):
|
75
|
+
value: int
|
76
|
+
|
77
|
+
def run(self) -> Outputs:
|
78
|
+
time.sleep(0.03)
|
79
|
+
return self.Outputs(value=self.item + 1)
|
80
|
+
|
81
|
+
# WHEN the node is run
|
82
|
+
node = TestNode()
|
83
|
+
outputs = list(node.run())
|
84
|
+
|
85
|
+
# THEN the node should return an empty output
|
86
|
+
fulfilled_output = outputs[-1]
|
87
|
+
assert fulfilled_output == BaseOutput(name="value", value=[])
|
@@ -1,8 +1,10 @@
|
|
1
|
+
import pytest
|
1
2
|
import json
|
2
|
-
from typing import List
|
3
|
+
from typing import List, Union
|
3
4
|
|
4
5
|
from vellum.client.types.chat_message import ChatMessage
|
5
6
|
from vellum.client.types.function_call import FunctionCall
|
7
|
+
from vellum.workflows.exceptions import NodeException
|
6
8
|
from vellum.workflows.nodes.bases.base import BaseNode
|
7
9
|
from vellum.workflows.nodes.core.templating_node.node import TemplatingNode
|
8
10
|
from vellum.workflows.state import BaseState
|
@@ -155,3 +157,39 @@ def test_templating_node__function_call_output():
|
|
155
157
|
|
156
158
|
# THEN the output is the expected function call
|
157
159
|
assert outputs.result == FunctionCall(name="test", arguments={"key": "value"})
|
160
|
+
|
161
|
+
|
162
|
+
def test_templating_node__blank_json_input():
|
163
|
+
"""Test that templating node properly handles blank JSON input."""
|
164
|
+
|
165
|
+
# GIVEN a templating node that tries to parse blank JSON
|
166
|
+
class BlankJsonTemplateNode(TemplatingNode[BaseState, Json]):
|
167
|
+
template = "{{ json.loads(data) }}"
|
168
|
+
inputs = {
|
169
|
+
"data": "", # Blank input
|
170
|
+
}
|
171
|
+
|
172
|
+
# WHEN the node is run
|
173
|
+
node = BlankJsonTemplateNode()
|
174
|
+
|
175
|
+
# THEN it should raise an appropriate error
|
176
|
+
with pytest.raises(NodeException) as exc_info:
|
177
|
+
node.run()
|
178
|
+
|
179
|
+
assert "Unable to render jinja template:\nCannot run json.loads() on empty input" in str(exc_info.value)
|
180
|
+
|
181
|
+
|
182
|
+
def test_templating_node__union_float_int_output():
|
183
|
+
# GIVEN a templating node that outputs either a float or an int
|
184
|
+
class UnionTemplateNode(TemplatingNode[BaseState, Union[float, int]]):
|
185
|
+
template = """{{ obj[\"score\"] | float }}"""
|
186
|
+
inputs = {
|
187
|
+
"obj": {"score": 42.5},
|
188
|
+
}
|
189
|
+
|
190
|
+
# WHEN the node is run
|
191
|
+
node = UnionTemplateNode()
|
192
|
+
outputs = node.run()
|
193
|
+
|
194
|
+
# THEN it should correctly parse as a float
|
195
|
+
assert outputs.result == 42.5
|