vellum-ai 0.12.7__py3-none-any.whl → 0.12.9__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/evaluations/resources.py +2 -2
- vellum/prompts/__init__.py +0 -0
- vellum/prompts/blocks/__init__.py +0 -0
- vellum/prompts/blocks/compilation.py +190 -0
- vellum/prompts/blocks/exceptions.py +2 -0
- vellum/prompts/blocks/tests/__init__.py +0 -0
- vellum/prompts/blocks/tests/test_compilation.py +110 -0
- vellum/prompts/blocks/types.py +36 -0
- vellum/utils/__init__.py +0 -0
- vellum/utils/templating/__init__.py +0 -0
- vellum/utils/templating/constants.py +28 -0
- vellum/{workflows/nodes/core/templating_node → utils/templating}/render.py +1 -1
- vellum/workflows/nodes/bases/__init__.py +0 -2
- vellum/workflows/nodes/bases/base.py +2 -6
- vellum/workflows/nodes/core/inline_subworkflow_node/node.py +13 -7
- vellum/workflows/nodes/core/inline_subworkflow_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/core/inline_subworkflow_node/tests/test_node.py +41 -0
- vellum/workflows/nodes/core/map_node/node.py +1 -1
- vellum/workflows/nodes/core/templating_node/node.py +12 -31
- vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +66 -0
- vellum/workflows/nodes/core/try_node/node.py +1 -3
- vellum/workflows/nodes/core/try_node/tests/test_node.py +1 -1
- vellum/workflows/nodes/displayable/bases/api_node/node.py +2 -2
- vellum/workflows/nodes/displayable/bases/search_node.py +5 -2
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +4 -2
- vellum/workflows/nodes/experimental/README.md +6 -0
- vellum/workflows/nodes/experimental/__init__.py +0 -0
- vellum/workflows/nodes/experimental/openai_chat_completion_node/__init__.py +5 -0
- vellum/workflows/nodes/experimental/openai_chat_completion_node/node.py +260 -0
- vellum/workflows/sandbox.py +7 -5
- vellum/workflows/state/context.py +5 -4
- vellum/workflows/tests/test_sandbox.py +2 -2
- vellum/workflows/utils/tests/test_vellum_variables.py +3 -0
- vellum/workflows/utils/vellum_variables.py +5 -4
- vellum/workflows/workflows/base.py +5 -2
- {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/METADATA +2 -1
- {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/RECORD +48 -33
- vellum_cli/tests/test_push.py +1 -1
- vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +6 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +207 -0
- vellum/workflows/nodes/bases/base_subworkflow_node/__init__.py +0 -5
- vellum/workflows/nodes/bases/base_subworkflow_node/node.py +0 -10
- /vellum/{workflows/nodes/core/templating_node → utils/templating}/custom_filters.py +0 -0
- /vellum/{workflows/nodes/core/templating_node → utils/templating}/exceptions.py +0 -0
- /vellum/{evaluations/utils → utils}/typing.py +0 -0
- /vellum/{evaluations/utils → utils}/uuid.py +0 -0
- {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/LICENSE +0 -0
- {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/WHEEL +0 -0
- {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
|
|
18
18
|
headers: typing.Dict[str, str] = {
|
19
19
|
"X-Fern-Language": "Python",
|
20
20
|
"X-Fern-SDK-Name": "vellum-ai",
|
21
|
-
"X-Fern-SDK-Version": "0.12.
|
21
|
+
"X-Fern-SDK-Version": "0.12.9",
|
22
22
|
}
|
23
23
|
headers["X_API_KEY"] = self.api_key
|
24
24
|
return headers
|
vellum/evaluations/resources.py
CHANGED
@@ -12,8 +12,6 @@ from vellum.evaluations.constants import DEFAULT_MAX_POLLING_DURATION_MS, DEFAUL
|
|
12
12
|
from vellum.evaluations.exceptions import TestSuiteRunResultsException
|
13
13
|
from vellum.evaluations.utils.env import get_api_key
|
14
14
|
from vellum.evaluations.utils.paginator import PaginatedResults, get_all_results
|
15
|
-
from vellum.evaluations.utils.typing import cast_not_optional
|
16
|
-
from vellum.evaluations.utils.uuid import is_valid_uuid
|
17
15
|
from vellum.types import (
|
18
16
|
ExternalTestCaseExecutionRequest,
|
19
17
|
NamedTestCaseVariableValueRequest,
|
@@ -24,6 +22,8 @@ from vellum.types import (
|
|
24
22
|
TestSuiteRunMetricOutput,
|
25
23
|
TestSuiteRunState,
|
26
24
|
)
|
25
|
+
from vellum.utils.typing import cast_not_optional
|
26
|
+
from vellum.utils.uuid import is_valid_uuid
|
27
27
|
|
28
28
|
logger = logging.getLogger(__name__)
|
29
29
|
|
File without changes
|
File without changes
|
@@ -0,0 +1,190 @@
|
|
1
|
+
import json
|
2
|
+
from typing import Optional, cast
|
3
|
+
|
4
|
+
from vellum import (
|
5
|
+
ChatMessage,
|
6
|
+
JsonVellumValue,
|
7
|
+
PromptBlock,
|
8
|
+
PromptRequestInput,
|
9
|
+
RichTextPromptBlock,
|
10
|
+
StringVellumValue,
|
11
|
+
VellumVariable,
|
12
|
+
)
|
13
|
+
from vellum.prompts.blocks.exceptions import PromptCompilationError
|
14
|
+
from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock, CompiledPromptBlock, CompiledValuePromptBlock
|
15
|
+
from vellum.utils.templating.constants import DEFAULT_JINJA_CUSTOM_FILTERS, DEFAULT_JINJA_GLOBALS
|
16
|
+
from vellum.utils.templating.render import render_sandboxed_jinja_template
|
17
|
+
from vellum.utils.typing import cast_not_optional
|
18
|
+
|
19
|
+
|
20
|
+
def compile_prompt_blocks(
    blocks: list[PromptBlock],
    inputs: list[PromptRequestInput],
    input_variables: list[VellumVariable],
) -> list[CompiledPromptBlock]:
    """Compile a list of Prompt Blocks, performing all variable substitutions and Jinja templating needed.

    Args:
        blocks: The prompt blocks to compile, in order. Blocks whose ``state`` is ``"DISABLED"`` are skipped.
        inputs: Values for the prompt's input variables. VARIABLE blocks are matched to an input by
            comparing ``str(block.input_variable)`` with the input's ``key``.
        input_variables: The prompt's declared input variables. Forwarded to the recursive call for
            nested CHAT_MESSAGE blocks; not otherwise read here.

    Returns:
        The compiled blocks, in the same order as ``blocks``. A CHAT_HISTORY input referenced by a
        VARIABLE block expands into one compiled chat message per history entry.

    Raises:
        PromptCompilationError: If a referenced input variable is missing, a VARIABLE block's input has an
            unsupported type, a FUNCTION_DEFINITION block is encountered, or the block type is unknown.
    """
    # Replace None values for CHAT_HISTORY / STRING inputs with empty values before use.
    sanitized_inputs = _sanitize_inputs(inputs)

    compiled_blocks: list[CompiledPromptBlock] = []
    for block in blocks:
        if block.state == "DISABLED":
            continue

        if block.block_type == "CHAT_MESSAGE":
            chat_role = cast_not_optional(block.chat_role)
            inner_blocks = cast_not_optional(block.blocks)
            unterminated = block.chat_message_unterminated or False

            inner_prompt_blocks = compile_prompt_blocks(
                inner_blocks,
                sanitized_inputs,
                input_variables,
            )
            # A chat message whose contents all compile away is omitted entirely.
            if not inner_prompt_blocks:
                continue

            compiled_blocks.append(
                CompiledChatMessagePromptBlock(
                    role=chat_role,
                    unterminated=unterminated,
                    source=block.chat_source,
                    # Only VALUE blocks can live inside a compiled chat message; nested chat messages are dropped.
                    blocks=[inner for inner in inner_prompt_blocks if inner.block_type == "VALUE"],
                    cache_config=block.cache_config,
                )
            )

        elif block.block_type == "JINJA":
            if block.template is None:
                continue

            rendered_template = render_sandboxed_jinja_template(
                template=block.template,
                input_values={input_.key: input_.value for input_ in sanitized_inputs},
                jinja_custom_filters=DEFAULT_JINJA_CUSTOM_FILTERS,
                # Previously the filters mapping was passed here by mistake, hiding the default globals
                # (json, datetime, ...) from prompt templates.
                jinja_globals=DEFAULT_JINJA_GLOBALS,
            )
            compiled_blocks.append(
                CompiledValuePromptBlock(
                    content=StringVellumValue(value=rendered_template),
                    cache_config=block.cache_config,
                )
            )

        elif block.block_type == "VARIABLE":
            compiled_input: Optional[PromptRequestInput] = next(
                (input_ for input_ in sanitized_inputs if input_.key == str(block.input_variable)), None
            )
            if compiled_input is None:
                raise PromptCompilationError(f"Input variable '{block.input_variable}' not found")

            if compiled_input.type == "CHAT_HISTORY":
                history = cast(list[ChatMessage], compiled_input.value)
                compiled_blocks.extend(_compile_chat_messages_as_prompt_blocks(history))
            elif compiled_input.type == "STRING":
                compiled_blocks.append(
                    CompiledValuePromptBlock(
                        content=StringVellumValue(value=compiled_input.value),
                        cache_config=block.cache_config,
                    )
                )
            elif compiled_input.type == "JSON":
                # Compare the input's *type*; comparing the input object itself to "JSON" never matched.
                compiled_blocks.append(
                    CompiledValuePromptBlock(
                        content=JsonVellumValue(value=compiled_input.value),
                        cache_config=block.cache_config,
                    )
                )
            else:
                raise PromptCompilationError(f"Invalid input type for variable block: {compiled_input.type}")

        elif block.block_type == "RICH_TEXT":
            value_block = _compile_rich_text_block_as_value_block(block=block, inputs=sanitized_inputs)
            compiled_blocks.append(value_block)

        elif block.block_type == "FUNCTION_DEFINITION":
            raise PromptCompilationError("Function definitions shouldn't go through compilation process")
        else:
            raise PromptCompilationError(f"Unknown block_type: {block.block_type}")

    return compiled_blocks
|
119
|
+
|
120
|
+
|
121
|
+
def _compile_chat_messages_as_prompt_blocks(chat_messages: list[ChatMessage]) -> list[CompiledChatMessagePromptBlock]:
    """Convert chat-history entries into compiled chat-message blocks.

    Messages whose ``content`` is None are skipped. ARRAY content is split so that each array item
    becomes its own value block; any other content becomes a single value block. Every resulting
    message is marked as terminated and keeps the original role and source.
    """
    compiled: list[CompiledChatMessagePromptBlock] = []
    for message in chat_messages:
        content = message.content
        if content is None:
            continue

        if content.type == "ARRAY":
            value_blocks = [CompiledValuePromptBlock(content=item) for item in content.value]
        else:
            value_blocks = [CompiledValuePromptBlock(content=content)]

        compiled.append(
            CompiledChatMessagePromptBlock(
                role=message.role,
                unterminated=False,
                blocks=value_blocks,
                source=message.source,
            )
        )

    return compiled
|
152
|
+
|
153
|
+
|
154
|
+
def _compile_rich_text_block_as_value_block(
    block: RichTextPromptBlock,
    inputs: list[PromptRequestInput],
) -> CompiledValuePromptBlock:
    """Flatten a RICH_TEXT block into a single string value block.

    PLAIN_TEXT children contribute their text verbatim. VARIABLE children are replaced by the input whose
    ``key`` equals ``str(child.input_variable)``: STRING values via ``str()``, JSON values pretty-printed
    with ``json.dumps(..., indent=4)``. The block's ``cache_config`` is carried over.

    Raises:
        PromptCompilationError: If a referenced variable is missing, its input is neither STRING nor JSON,
            or a child block has any other block_type.
    """
    pieces: list[str] = []
    for child in block.blocks:
        if child.block_type == "PLAIN_TEXT":
            pieces.append(child.text)
            continue

        if child.block_type != "VARIABLE":
            raise PromptCompilationError(f"Invalid child block_type for RICH_TEXT: {child.block_type}")

        wanted_key = str(child.input_variable)
        variable = next((candidate for candidate in inputs if candidate.key == wanted_key), None)
        if variable is None:
            raise PromptCompilationError(f"Input variable '{child.input_variable}' not found")

        if variable.type == "STRING":
            pieces.append(str(variable.value))
        elif variable.type == "JSON":
            pieces.append(json.dumps(variable.value, indent=4))
        else:
            raise PromptCompilationError(
                f"Input variable '{child.input_variable}' must be of type STRING or JSON"
            )

    flattened = "".join(pieces)
    return CompiledValuePromptBlock(content=StringVellumValue(value=flattened), cache_config=block.cache_config)
|
178
|
+
|
179
|
+
|
180
|
+
def _sanitize_inputs(inputs: list[PromptRequestInput]) -> list[PromptRequestInput]:
    """Return a new list with missing values of CHAT_HISTORY / STRING inputs filled in.

    A CHAT_HISTORY input whose value is None is replaced by a copy holding an empty list. A STRING input
    whose value is None is replaced by a copy holding ``""``. Every other input is passed through as the
    same object, not a copy.
    """

    def _sanitize_one(candidate: PromptRequestInput) -> PromptRequestInput:
        if candidate.value is not None:
            return candidate
        if candidate.type == "CHAT_HISTORY":
            return candidate.model_copy(update={"value": cast(list[ChatMessage], [])})
        if candidate.type == "STRING":
            return candidate.model_copy(update={"value": ""})
        return candidate

    return [_sanitize_one(candidate) for candidate in inputs]
|
File without changes
|
@@ -0,0 +1,110 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
from vellum import (
|
4
|
+
ChatMessagePromptBlock,
|
5
|
+
JinjaPromptBlock,
|
6
|
+
PlainTextPromptBlock,
|
7
|
+
PromptRequestStringInput,
|
8
|
+
RichTextPromptBlock,
|
9
|
+
StringVellumValue,
|
10
|
+
VariablePromptBlock,
|
11
|
+
VellumVariable,
|
12
|
+
)
|
13
|
+
from vellum.prompts.blocks.compilation import compile_prompt_blocks
|
14
|
+
from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock, CompiledValuePromptBlock
|
15
|
+
|
16
|
+
|
17
|
+
@pytest.mark.parametrize(
    ["blocks", "inputs", "input_variables", "expected"],
    [
        pytest.param([], [], [], [], id="empty"),
        pytest.param(
            [JinjaPromptBlock(template="Hello, world!")],
            [],
            [],
            [CompiledValuePromptBlock(content=StringVellumValue(value="Hello, world!"))],
            id="jinja-no-variables",
        ),
        pytest.param(
            [JinjaPromptBlock(template="Repeat back to me {{ echo }}")],
            [PromptRequestStringInput(key="echo", value="Hello, world!")],
            [VellumVariable(id="1", type="STRING", key="echo")],
            [CompiledValuePromptBlock(content=StringVellumValue(value="Repeat back to me Hello, world!"))],
            id="jinja-with-variables",
        ),
        pytest.param(
            [RichTextPromptBlock(blocks=[PlainTextPromptBlock(text="Hello, world!")])],
            [],
            [],
            [CompiledValuePromptBlock(content=StringVellumValue(value="Hello, world!"))],
            id="rich-text-no-variables",
        ),
        pytest.param(
            [
                RichTextPromptBlock(
                    blocks=[
                        PlainTextPromptBlock(text='Repeat back to me "'),
                        VariablePromptBlock(input_variable="echo"),
                        PlainTextPromptBlock(text='".'),
                    ]
                )
            ],
            [PromptRequestStringInput(key="echo", value="Hello, world!")],
            [VellumVariable(id="901ec2d6-430c-4341-b963-ca689006f5cc", type="STRING", key="echo")],
            [CompiledValuePromptBlock(content=StringVellumValue(value='Repeat back to me "Hello, world!".'))],
            id="rich-text-with-variables",
        ),
        pytest.param(
            [
                ChatMessagePromptBlock(
                    chat_role="USER",
                    blocks=[
                        RichTextPromptBlock(
                            blocks=[
                                PlainTextPromptBlock(text='Repeat back to me "'),
                                VariablePromptBlock(input_variable="echo"),
                                PlainTextPromptBlock(text='".'),
                            ]
                        )
                    ],
                )
            ],
            [PromptRequestStringInput(key="echo", value="Hello, world!")],
            [VellumVariable(id="901ec2d6-430c-4341-b963-ca689006f5cc", type="STRING", key="echo")],
            [
                CompiledChatMessagePromptBlock(
                    role="USER",
                    blocks=[
                        CompiledValuePromptBlock(content=StringVellumValue(value='Repeat back to me "Hello, world!".'))
                    ],
                ),
            ],
            id="chat-message",
        ),
    ],
)
def test_compile_prompt_blocks__happy(blocks, inputs, input_variables, expected):
    """Jinja, rich-text and chat-message blocks compile into the expected compiled blocks."""
    # WHEN the prompt blocks are compiled against the given inputs
    compiled = compile_prompt_blocks(blocks=blocks, inputs=inputs, input_variables=input_variables)

    # THEN the compiled blocks match exactly
    assert compiled == expected
|
@@ -0,0 +1,36 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import Annotated, Literal, Optional, Union
|
4
|
+
|
5
|
+
from vellum import ArrayVellumValue, ChatMessageRole, EphemeralPromptCacheConfig, VellumValue
|
6
|
+
from vellum.client.core import UniversalBaseModel
|
7
|
+
|
8
|
+
|
9
|
+
# Shared base for all compiled prompt blocks: every compiled block may carry an optional cache config.
# (Comments are used instead of class docstrings so the generated JSON schema is unchanged.)
class BaseCompiledPromptBlock(UniversalBaseModel):
    cache_config: Optional[EphemeralPromptCacheConfig] = None


# A compiled block holding a single resolved value (e.g. rendered Jinja text or a substituted input).
class CompiledValuePromptBlock(BaseCompiledPromptBlock):
    block_type: Literal["VALUE"] = "VALUE"
    content: VellumValue


# A compiled chat message; its body is a list of value blocks only.
class CompiledChatMessagePromptBlock(BaseCompiledPromptBlock):
    block_type: Literal["CHAT_MESSAGE"] = "CHAT_MESSAGE"
    role: ChatMessageRole = "ASSISTANT"
    unterminated: bool = False
    blocks: list[CompiledValuePromptBlock] = []
    source: Optional[str] = None


# Any compiled block. NOTE(review): the "block_type" string is plain Annotated metadata, not a pydantic
# discriminator; confirm whether a discriminated union was intended.
CompiledPromptBlock = Annotated[Union[CompiledValuePromptBlock, CompiledChatMessagePromptBlock], "block_type"]

# Resolve annotations left as strings by `from __future__ import annotations`. Presumably ArrayVellumValue
# needs this because it refers back to VellumValue; confirm against the vellum types module.
ArrayVellumValue.model_rebuild()
CompiledValuePromptBlock.model_rebuild()
|
vellum/utils/__init__.py
ADDED
File without changes
|
File without changes
|
@@ -0,0 +1,28 @@
|
|
1
|
+
import datetime
|
2
|
+
import itertools
|
3
|
+
import json
|
4
|
+
import random
|
5
|
+
import re
|
6
|
+
from typing import Any, Callable, Dict, Union
|
7
|
+
|
8
|
+
import dateutil
|
9
|
+
import pydash
|
10
|
+
import pytz
|
11
|
+
import yaml
|
12
|
+
|
13
|
+
from vellum.utils.templating.custom_filters import is_valid_json_string
|
14
|
+
|
15
|
+
# Names made available as globals to Jinja templates (passed as `jinja_globals` by the templating callers).
DEFAULT_JINJA_GLOBALS: Dict[str, Any] = dict(
    datetime=datetime,
    dateutil=dateutil,
    itertools=itertools,
    json=json,
    pydash=pydash,
    pytz=pytz,
    random=random,
    re=re,
    yaml=yaml,
)

# Extra Jinja filters, keyed by the name templates use to invoke them.
DEFAULT_JINJA_CUSTOM_FILTERS: Dict[str, Callable[[Union[str, bytes]], bool]] = dict(
    is_valid_json_string=is_valid_json_string,
)
|
@@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, Optional, Union
|
|
3
3
|
|
4
4
|
from jinja2.sandbox import SandboxedEnvironment
|
5
5
|
|
6
|
-
from vellum.
|
6
|
+
from vellum.utils.templating.exceptions import JinjaTemplateError
|
7
7
|
from vellum.workflows.state.encoder import DefaultStateEncoder
|
8
8
|
|
9
9
|
|
@@ -344,8 +344,8 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
|
|
344
344
|
all_inputs = {}
|
345
345
|
for key, value in inputs.items():
|
346
346
|
path_parts = key.split(".")
|
347
|
-
|
348
|
-
inputs_key = reduce(lambda acc, part: acc[part], path_parts[1:],
|
347
|
+
node_attribute_descriptor = getattr(self.__class__, path_parts[0])
|
348
|
+
inputs_key = reduce(lambda acc, part: acc[part], path_parts[1:], node_attribute_descriptor)
|
349
349
|
all_inputs[inputs_key] = value
|
350
350
|
|
351
351
|
self._inputs = MappingProxyType(all_inputs)
|
@@ -355,7 +355,3 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
|
|
355
355
|
|
356
356
|
def __repr__(self) -> str:
|
357
357
|
return str(self.__class__)
|
358
|
-
|
359
|
-
|
360
|
-
class MyNode2(BaseNode):
|
361
|
-
pass
|
@@ -1,12 +1,14 @@
|
|
1
|
-
from typing import TYPE_CHECKING, Generic, Iterator, Optional, Set, Type, TypeVar
|
1
|
+
from typing import TYPE_CHECKING, ClassVar, Generic, Iterator, Optional, Set, Type, TypeVar, Union
|
2
2
|
|
3
3
|
from vellum.workflows.context import execution_context, get_parent_context
|
4
4
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
5
5
|
from vellum.workflows.exceptions import NodeException
|
6
|
-
from vellum.workflows.
|
6
|
+
from vellum.workflows.inputs.base import BaseInputs
|
7
|
+
from vellum.workflows.nodes.bases.base import BaseNode
|
7
8
|
from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
|
8
9
|
from vellum.workflows.state.base import BaseState
|
9
10
|
from vellum.workflows.state.context import WorkflowContext
|
11
|
+
from vellum.workflows.types.core import EntityInputsInterface
|
10
12
|
from vellum.workflows.types.generics import StateType, WorkflowInputsType
|
11
13
|
|
12
14
|
if TYPE_CHECKING:
|
@@ -15,7 +17,7 @@ if TYPE_CHECKING:
|
|
15
17
|
InnerStateType = TypeVar("InnerStateType", bound=BaseState)
|
16
18
|
|
17
19
|
|
18
|
-
class InlineSubworkflowNode(
|
20
|
+
class InlineSubworkflowNode(BaseNode[StateType], Generic[StateType, WorkflowInputsType, InnerStateType]):
|
19
21
|
"""
|
20
22
|
Used to execute a Subworkflow defined inline.
|
21
23
|
|
@@ -24,14 +26,13 @@ class InlineSubworkflowNode(BaseSubworkflowNode[StateType], Generic[StateType, W
|
|
24
26
|
"""
|
25
27
|
|
26
28
|
subworkflow: Type["BaseWorkflow[WorkflowInputsType, InnerStateType]"]
|
29
|
+
subworkflow_inputs: ClassVar[Union[EntityInputsInterface, BaseInputs]] = {}
|
27
30
|
|
28
31
|
def run(self) -> Iterator[BaseOutput]:
|
29
32
|
with execution_context(parent_context=get_parent_context() or self._context.parent_context):
|
30
33
|
subworkflow = self.subworkflow(
|
31
34
|
parent_state=self.state,
|
32
|
-
context=WorkflowContext(
|
33
|
-
_vellum_client=self._context._vellum_client,
|
34
|
-
),
|
35
|
+
context=WorkflowContext(vellum_client=self._context.vellum_client),
|
35
36
|
)
|
36
37
|
subworkflow_stream = subworkflow.stream(
|
37
38
|
inputs=self._compile_subworkflow_inputs(),
|
@@ -68,4 +69,9 @@ class InlineSubworkflowNode(BaseSubworkflowNode[StateType], Generic[StateType, W
|
|
68
69
|
|
69
70
|
def _compile_subworkflow_inputs(self) -> WorkflowInputsType:
|
70
71
|
inputs_class = self.subworkflow.get_inputs_class()
|
71
|
-
|
72
|
+
if isinstance(self.subworkflow_inputs, dict):
|
73
|
+
return inputs_class(**self.subworkflow_inputs)
|
74
|
+
elif isinstance(self.subworkflow_inputs, inputs_class):
|
75
|
+
return self.subworkflow_inputs
|
76
|
+
else:
|
77
|
+
raise ValueError(f"Invalid subworkflow inputs type: {type(self.subworkflow_inputs)}")
|
File without changes
|
@@ -0,0 +1,41 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
from vellum.workflows.inputs.base import BaseInputs
|
4
|
+
from vellum.workflows.nodes.bases.base import BaseNode
|
5
|
+
from vellum.workflows.nodes.core.inline_subworkflow_node.node import InlineSubworkflowNode
|
6
|
+
from vellum.workflows.outputs.base import BaseOutput
|
7
|
+
from vellum.workflows.state.base import BaseState
|
8
|
+
from vellum.workflows.workflows.base import BaseWorkflow
|
9
|
+
|
10
|
+
|
11
|
+
class Inputs(BaseInputs):
    foo: str


# Inner node that simply echoes the subworkflow's `foo` input as its `out` output.
class MyInnerNode(BaseNode):
    class Outputs(BaseNode.Outputs):
        out = Inputs.foo


# Single-node subworkflow exposing the inner node's output as the workflow output `out`.
class MySubworkflow(BaseWorkflow[Inputs, BaseState]):
    graph = MyInnerNode

    class Outputs(BaseWorkflow.Outputs):
        out = MyInnerNode.Outputs.out


@pytest.mark.parametrize("inputs", [{"foo": "bar"}, Inputs(foo="bar")])
def test_inline_subworkflow_node__inputs(inputs):
    """Subworkflow inputs may be supplied either as a plain dict or as an Inputs instance."""

    # GIVEN a node setup with subworkflow inputs
    class MyNode(InlineSubworkflowNode):
        subworkflow = MySubworkflow
        subworkflow_inputs = inputs

    # WHEN the node is run
    emitted = list(MyNode().run())

    # THEN the output is as expected
    assert emitted == [BaseOutput(name="out", value="bar")]
|
@@ -108,7 +108,7 @@ class MapNode(BaseNode, Generic[StateType, MapNodeItemType]):
|
|
108
108
|
self._run_subworkflow(item=item, index=index)
|
109
109
|
|
110
110
|
def _run_subworkflow(self, *, item: MapNodeItemType, index: int) -> None:
|
111
|
-
context = WorkflowContext(
|
111
|
+
context = WorkflowContext(vellum_client=self._context.vellum_client)
|
112
112
|
subworkflow = self.subworkflow(parent_state=self.state, context=context)
|
113
113
|
events = subworkflow.stream(
|
114
114
|
inputs=self.SubworkflowInputs(index=index, item=item, all_items=self.items),
|
@@ -1,42 +1,17 @@
|
|
1
|
-
import datetime
|
2
|
-
import itertools
|
3
1
|
import json
|
4
|
-
import random
|
5
|
-
import re
|
6
2
|
from typing import Any, Callable, ClassVar, Dict, Generic, Mapping, Tuple, Type, TypeVar, Union, get_args
|
7
3
|
|
8
|
-
import
|
9
|
-
import
|
10
|
-
import
|
11
|
-
import yaml
|
12
|
-
|
4
|
+
from vellum.utils.templating.constants import DEFAULT_JINJA_CUSTOM_FILTERS, DEFAULT_JINJA_GLOBALS
|
5
|
+
from vellum.utils.templating.exceptions import JinjaTemplateError
|
6
|
+
from vellum.utils.templating.render import render_sandboxed_jinja_template
|
13
7
|
from vellum.workflows.errors import WorkflowErrorCode
|
14
8
|
from vellum.workflows.exceptions import NodeException
|
15
9
|
from vellum.workflows.nodes.bases import BaseNode
|
16
10
|
from vellum.workflows.nodes.bases.base import BaseNodeMeta
|
17
|
-
from vellum.workflows.
|
18
|
-
from vellum.workflows.nodes.core.templating_node.exceptions import JinjaTemplateError
|
19
|
-
from vellum.workflows.nodes.core.templating_node.render import render_sandboxed_jinja_template
|
20
|
-
from vellum.workflows.types.core import EntityInputsInterface
|
11
|
+
from vellum.workflows.types.core import EntityInputsInterface, Json
|
21
12
|
from vellum.workflows.types.generics import StateType
|
22
13
|
from vellum.workflows.types.utils import get_original_base
|
23
14
|
|
24
|
-
_DEFAULT_JINJA_GLOBALS: Dict[str, Any] = {
|
25
|
-
"datetime": datetime,
|
26
|
-
"dateutil": dateutil,
|
27
|
-
"itertools": itertools,
|
28
|
-
"json": json,
|
29
|
-
"pydash": pydash,
|
30
|
-
"pytz": pytz,
|
31
|
-
"random": random,
|
32
|
-
"re": re,
|
33
|
-
"yaml": yaml,
|
34
|
-
}
|
35
|
-
|
36
|
-
_DEFAULT_JINJA_CUSTOM_FILTERS: Dict[str, Callable[[Union[str, bytes]], bool]] = {
|
37
|
-
"is_valid_json_string": is_valid_json_string,
|
38
|
-
}
|
39
|
-
|
40
15
|
_OutputType = TypeVar("_OutputType")
|
41
16
|
|
42
17
|
|
@@ -78,8 +53,8 @@ class TemplatingNode(BaseNode[StateType], Generic[StateType, _OutputType], metac
|
|
78
53
|
# The inputs to render the template with.
|
79
54
|
inputs: ClassVar[EntityInputsInterface]
|
80
55
|
|
81
|
-
jinja_globals: Dict[str, Any] =
|
82
|
-
jinja_custom_filters: Mapping[str, Callable[[Union[str, bytes]], bool]] =
|
56
|
+
jinja_globals: Dict[str, Any] = DEFAULT_JINJA_GLOBALS
|
57
|
+
jinja_custom_filters: Mapping[str, Callable[[Union[str, bytes]], bool]] = DEFAULT_JINJA_CUSTOM_FILTERS
|
83
58
|
|
84
59
|
class Outputs(BaseNode.Outputs):
|
85
60
|
"""
|
@@ -113,6 +88,12 @@ class TemplatingNode(BaseNode[StateType], Generic[StateType, _OutputType], metac
|
|
113
88
|
if output_type is bool:
|
114
89
|
return bool(rendered_template)
|
115
90
|
|
91
|
+
if output_type is Json:
|
92
|
+
try:
|
93
|
+
return json.loads(rendered_template)
|
94
|
+
except json.JSONDecodeError:
|
95
|
+
raise ValueError("Invalid JSON format for rendered_template")
|
96
|
+
|
116
97
|
raise ValueError(f"Unsupported output type: {output_type}")
|
117
98
|
|
118
99
|
def run(self) -> Outputs:
|