vellum-ai 0.12.7__py3-none-any.whl → 0.12.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/evaluations/resources.py +2 -2
- vellum/prompts/__init__.py +0 -0
- vellum/prompts/blocks/__init__.py +0 -0
- vellum/prompts/blocks/compilation.py +190 -0
- vellum/prompts/blocks/exceptions.py +2 -0
- vellum/prompts/blocks/tests/__init__.py +0 -0
- vellum/prompts/blocks/tests/test_compilation.py +110 -0
- vellum/prompts/blocks/types.py +36 -0
- vellum/utils/__init__.py +0 -0
- vellum/utils/templating/__init__.py +0 -0
- vellum/utils/templating/constants.py +28 -0
- vellum/{workflows/nodes/core/templating_node → utils/templating}/render.py +1 -1
- vellum/workflows/nodes/bases/__init__.py +0 -2
- vellum/workflows/nodes/bases/base.py +2 -6
- vellum/workflows/nodes/core/inline_subworkflow_node/node.py +13 -7
- vellum/workflows/nodes/core/inline_subworkflow_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/core/inline_subworkflow_node/tests/test_node.py +41 -0
- vellum/workflows/nodes/core/map_node/node.py +1 -1
- vellum/workflows/nodes/core/templating_node/node.py +12 -31
- vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +66 -0
- vellum/workflows/nodes/core/try_node/node.py +1 -3
- vellum/workflows/nodes/core/try_node/tests/test_node.py +1 -1
- vellum/workflows/nodes/displayable/bases/api_node/node.py +2 -2
- vellum/workflows/nodes/displayable/bases/search_node.py +5 -2
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +4 -2
- vellum/workflows/nodes/experimental/README.md +6 -0
- vellum/workflows/nodes/experimental/__init__.py +0 -0
- vellum/workflows/nodes/experimental/openai_chat_completion_node/__init__.py +5 -0
- vellum/workflows/nodes/experimental/openai_chat_completion_node/node.py +260 -0
- vellum/workflows/sandbox.py +7 -5
- vellum/workflows/state/context.py +5 -4
- vellum/workflows/tests/test_sandbox.py +2 -2
- vellum/workflows/utils/tests/test_vellum_variables.py +3 -0
- vellum/workflows/utils/vellum_variables.py +5 -4
- vellum/workflows/workflows/base.py +5 -2
- {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/METADATA +2 -1
- {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/RECORD +48 -33
- vellum_cli/tests/test_push.py +1 -1
- vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +6 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +207 -0
- vellum/workflows/nodes/bases/base_subworkflow_node/__init__.py +0 -5
- vellum/workflows/nodes/bases/base_subworkflow_node/node.py +0 -10
- /vellum/{workflows/nodes/core/templating_node → utils/templating}/custom_filters.py +0 -0
- /vellum/{workflows/nodes/core/templating_node → utils/templating}/exceptions.py +0 -0
- /vellum/{evaluations/utils → utils}/typing.py +0 -0
- /vellum/{evaluations/utils → utils}/uuid.py +0 -0
- {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/LICENSE +0 -0
- {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/WHEEL +0 -0
- {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
|
|
18
18
|
headers: typing.Dict[str, str] = {
|
19
19
|
"X-Fern-Language": "Python",
|
20
20
|
"X-Fern-SDK-Name": "vellum-ai",
|
21
|
-
"X-Fern-SDK-Version": "0.12.
|
21
|
+
"X-Fern-SDK-Version": "0.12.9",
|
22
22
|
}
|
23
23
|
headers["X_API_KEY"] = self.api_key
|
24
24
|
return headers
|
vellum/evaluations/resources.py
CHANGED
@@ -12,8 +12,6 @@ from vellum.evaluations.constants import DEFAULT_MAX_POLLING_DURATION_MS, DEFAUL
|
|
12
12
|
from vellum.evaluations.exceptions import TestSuiteRunResultsException
|
13
13
|
from vellum.evaluations.utils.env import get_api_key
|
14
14
|
from vellum.evaluations.utils.paginator import PaginatedResults, get_all_results
|
15
|
-
from vellum.evaluations.utils.typing import cast_not_optional
|
16
|
-
from vellum.evaluations.utils.uuid import is_valid_uuid
|
17
15
|
from vellum.types import (
|
18
16
|
ExternalTestCaseExecutionRequest,
|
19
17
|
NamedTestCaseVariableValueRequest,
|
@@ -24,6 +22,8 @@ from vellum.types import (
|
|
24
22
|
TestSuiteRunMetricOutput,
|
25
23
|
TestSuiteRunState,
|
26
24
|
)
|
25
|
+
from vellum.utils.typing import cast_not_optional
|
26
|
+
from vellum.utils.uuid import is_valid_uuid
|
27
27
|
|
28
28
|
logger = logging.getLogger(__name__)
|
29
29
|
|
File without changes
|
File without changes
|
@@ -0,0 +1,190 @@
|
|
1
|
+
import json
|
2
|
+
from typing import Optional, cast
|
3
|
+
|
4
|
+
from vellum import (
|
5
|
+
ChatMessage,
|
6
|
+
JsonVellumValue,
|
7
|
+
PromptBlock,
|
8
|
+
PromptRequestInput,
|
9
|
+
RichTextPromptBlock,
|
10
|
+
StringVellumValue,
|
11
|
+
VellumVariable,
|
12
|
+
)
|
13
|
+
from vellum.prompts.blocks.exceptions import PromptCompilationError
|
14
|
+
from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock, CompiledPromptBlock, CompiledValuePromptBlock
|
15
|
+
from vellum.utils.templating.constants import DEFAULT_JINJA_CUSTOM_FILTERS
|
16
|
+
from vellum.utils.templating.render import render_sandboxed_jinja_template
|
17
|
+
from vellum.utils.typing import cast_not_optional
|
18
|
+
|
19
|
+
|
20
|
+
def compile_prompt_blocks(
    blocks: list[PromptBlock],
    inputs: list[PromptRequestInput],
    input_variables: list[VellumVariable],
) -> list[CompiledPromptBlock]:
    """Compiles a list of Prompt Blocks, performing all variable substitutions and Jinja templating needed.

    Args:
        blocks: The Prompt Blocks to compile. Blocks whose ``state`` is ``"DISABLED"`` are skipped.
        inputs: The values available for substitution, matched to blocks by ``key``. ``None`` values of
            ``CHAT_HISTORY`` and ``STRING`` inputs are replaced with empty defaults before compiling.
        input_variables: The declared input variables; forwarded unchanged to nested compilation.

    Returns:
        The compiled blocks in the order they were produced.

    Raises:
        PromptCompilationError: If a variable block references a missing input, an input has an
            unsupported type, a ``FUNCTION_DEFINITION`` block is encountered, or the block type is unknown.
    """
    # Local import keeps this function self-contained; the module header only imports the filters.
    from vellum.utils.templating.constants import DEFAULT_JINJA_GLOBALS

    sanitized_inputs = _sanitize_inputs(inputs)

    compiled_blocks: list[CompiledPromptBlock] = []
    for block in blocks:
        if block.state == "DISABLED":
            continue

        if block.block_type == "CHAT_MESSAGE":
            chat_role = cast_not_optional(block.chat_role)
            inner_blocks = cast_not_optional(block.blocks)
            unterminated = block.chat_message_unterminated or False

            inner_prompt_blocks = compile_prompt_blocks(
                inner_blocks,
                sanitized_inputs,
                input_variables,
            )
            # A chat message that compiles to nothing is dropped entirely.
            if not inner_prompt_blocks:
                continue

            compiled_blocks.append(
                CompiledChatMessagePromptBlock(
                    role=chat_role,
                    unterminated=unterminated,
                    source=block.chat_source,
                    # Compiled chat messages only hold value blocks; nested chat messages are discarded.
                    blocks=[inner for inner in inner_prompt_blocks if inner.block_type == "VALUE"],
                    cache_config=block.cache_config,
                )
            )

        elif block.block_type == "JINJA":
            if block.template is None:
                continue

            rendered_template = render_sandboxed_jinja_template(
                template=block.template,
                input_values={input_.key: input_.value for input_ in sanitized_inputs},
                jinja_custom_filters=DEFAULT_JINJA_CUSTOM_FILTERS,
                # Globals are the module namespace (datetime, json, re, ...), not the filters mapping.
                jinja_globals=DEFAULT_JINJA_GLOBALS,
            )
            jinja_content = StringVellumValue(value=rendered_template)

            compiled_blocks.append(
                CompiledValuePromptBlock(
                    content=jinja_content,
                    cache_config=block.cache_config,
                )
            )

        elif block.block_type == "VARIABLE":
            compiled_input: Optional[PromptRequestInput] = next(
                (input_ for input_ in sanitized_inputs if input_.key == str(block.input_variable)), None
            )
            if compiled_input is None:
                raise PromptCompilationError(f"Input variable '{block.input_variable}' not found")

            if compiled_input.type == "CHAT_HISTORY":
                history = cast(list[ChatMessage], compiled_input.value)
                compiled_blocks.extend(_compile_chat_messages_as_prompt_blocks(history))
            elif compiled_input.type == "STRING":
                compiled_blocks.append(
                    CompiledValuePromptBlock(
                        content=StringVellumValue(value=compiled_input.value),
                        cache_config=block.cache_config,
                    )
                )
            elif compiled_input.type == "JSON":
                # Compare the input's ``type`` (not the input object itself) so JSON inputs are reachable.
                compiled_blocks.append(
                    CompiledValuePromptBlock(
                        content=JsonVellumValue(value=compiled_input.value),
                        cache_config=block.cache_config,
                    )
                )
            else:
                raise PromptCompilationError(f"Invalid input type for variable block: {compiled_input.type}")

        elif block.block_type == "RICH_TEXT":
            value_block = _compile_rich_text_block_as_value_block(block=block, inputs=sanitized_inputs)
            compiled_blocks.append(value_block)

        elif block.block_type == "FUNCTION_DEFINITION":
            raise PromptCompilationError("Function definitions shouldn't go through compilation process")
        else:
            raise PromptCompilationError(f"Unknown block_type: {block.block_type}")

    return compiled_blocks
|
119
|
+
|
120
|
+
|
121
|
+
def _compile_chat_messages_as_prompt_blocks(chat_messages: list[ChatMessage]) -> list[CompiledChatMessagePromptBlock]:
    """Convert chat-history messages into compiled chat message blocks.

    Messages whose ``content`` is ``None`` are skipped. ``ARRAY`` content is expanded so that each
    item becomes its own value block; any other content becomes a single value block. Every
    resulting chat message is terminated (``unterminated=False``) and keeps the message's role and source.
    """
    compiled_messages: list[CompiledChatMessagePromptBlock] = []
    for message in chat_messages:
        content = message.content
        if content is None:
            continue

        if content.type == "ARRAY":
            # NOTE(review): array items are chat-message content objects handed straight to a
            # VellumValue-typed field; confirm they validate as the intended VellumValue variants.
            value_blocks = [CompiledValuePromptBlock(content=item) for item in content.value]
        else:
            value_blocks = [CompiledValuePromptBlock(content=content)]

        compiled_messages.append(
            CompiledChatMessagePromptBlock(
                role=message.role,
                unterminated=False,
                blocks=value_blocks,
                source=message.source,
            )
        )

    return compiled_messages
|
152
|
+
|
153
|
+
|
154
|
+
def _compile_rich_text_block_as_value_block(
    block: RichTextPromptBlock,
    inputs: list[PromptRequestInput],
) -> CompiledValuePromptBlock:
    """Flatten a RICH_TEXT block into a single string value block.

    ``PLAIN_TEXT`` children contribute their text verbatim. ``VARIABLE`` children are replaced by the
    value of the input whose ``key`` matches the variable. STRING values go through ``str()``, and
    JSON values go through ``json.dumps(..., indent=4)``.

    Args:
        block: The rich-text block whose children are concatenated.
        inputs: The (already sanitized) inputs used to resolve variable children.

    Returns:
        A value block holding the concatenated string, carrying the block's ``cache_config``.

    Raises:
        PromptCompilationError: If a referenced input is missing, is not STRING/JSON, or a child has an
            unsupported ``block_type``.
    """
    # Collect fragments and join once rather than repeatedly concatenating immutable strings.
    parts: list[str] = []
    for child_block in block.blocks:
        if child_block.block_type == "PLAIN_TEXT":
            parts.append(child_block.text)
        elif child_block.block_type == "VARIABLE":
            variable = next((input_ for input_ in inputs if input_.key == str(child_block.input_variable)), None)
            if variable is None:
                raise PromptCompilationError(f"Input variable '{child_block.input_variable}' not found")
            elif variable.type == "STRING":
                parts.append(str(variable.value))
            elif variable.type == "JSON":
                parts.append(json.dumps(variable.value, indent=4))
            else:
                raise PromptCompilationError(
                    f"Input variable '{child_block.input_variable}' must be of type STRING or JSON"
                )
        else:
            raise PromptCompilationError(f"Invalid child block_type for RICH_TEXT: {child_block.block_type}")

    return CompiledValuePromptBlock(content=StringVellumValue(value="".join(parts)), cache_config=block.cache_config)
|
178
|
+
|
179
|
+
|
180
|
+
def _sanitize_inputs(inputs: list[PromptRequestInput]) -> list[PromptRequestInput]:
    """Return ``inputs`` with ``None`` values swapped for empty defaults.

    A ``CHAT_HISTORY`` input with no value becomes an empty list, and a ``STRING`` input with no
    value becomes ``""``. Replacements are copies made with ``model_copy``. Every other input is
    returned as the same object, and the original order is kept.
    """

    def _sanitize_one(input_: PromptRequestInput) -> PromptRequestInput:
        if input_.value is not None:
            return input_
        if input_.type == "CHAT_HISTORY":
            return input_.model_copy(update={"value": cast(list[ChatMessage], [])})
        if input_.type == "STRING":
            return input_.model_copy(update={"value": ""})
        return input_

    return [_sanitize_one(input_) for input_ in inputs]
|
File without changes
|
@@ -0,0 +1,110 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
from vellum import (
|
4
|
+
ChatMessagePromptBlock,
|
5
|
+
JinjaPromptBlock,
|
6
|
+
PlainTextPromptBlock,
|
7
|
+
PromptRequestStringInput,
|
8
|
+
RichTextPromptBlock,
|
9
|
+
StringVellumValue,
|
10
|
+
VariablePromptBlock,
|
11
|
+
VellumVariable,
|
12
|
+
)
|
13
|
+
from vellum.prompts.blocks.compilation import compile_prompt_blocks
|
14
|
+
from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock, CompiledValuePromptBlock
|
15
|
+
|
16
|
+
|
17
|
+
@pytest.mark.parametrize(
    ["blocks", "inputs", "input_variables", "expected"],
    [
        pytest.param([], [], [], [], id="empty"),
        pytest.param(
            [JinjaPromptBlock(template="Hello, world!")],
            [],
            [],
            [CompiledValuePromptBlock(content=StringVellumValue(value="Hello, world!"))],
            id="jinja-no-variables",
        ),
        pytest.param(
            [JinjaPromptBlock(template="Repeat back to me {{ echo }}")],
            [PromptRequestStringInput(key="echo", value="Hello, world!")],
            [VellumVariable(id="1", type="STRING", key="echo")],
            [CompiledValuePromptBlock(content=StringVellumValue(value="Repeat back to me Hello, world!"))],
            id="jinja-with-variables",
        ),
        pytest.param(
            [RichTextPromptBlock(blocks=[PlainTextPromptBlock(text="Hello, world!")])],
            [],
            [],
            [CompiledValuePromptBlock(content=StringVellumValue(value="Hello, world!"))],
            id="rich-text-no-variables",
        ),
        pytest.param(
            [
                RichTextPromptBlock(
                    blocks=[
                        PlainTextPromptBlock(text='Repeat back to me "'),
                        VariablePromptBlock(input_variable="echo"),
                        PlainTextPromptBlock(text='".'),
                    ]
                )
            ],
            [PromptRequestStringInput(key="echo", value="Hello, world!")],
            [VellumVariable(id="901ec2d6-430c-4341-b963-ca689006f5cc", type="STRING", key="echo")],
            [CompiledValuePromptBlock(content=StringVellumValue(value='Repeat back to me "Hello, world!".'))],
            id="rich-text-with-variables",
        ),
        pytest.param(
            [
                ChatMessagePromptBlock(
                    chat_role="USER",
                    blocks=[
                        RichTextPromptBlock(
                            blocks=[
                                PlainTextPromptBlock(text='Repeat back to me "'),
                                VariablePromptBlock(input_variable="echo"),
                                PlainTextPromptBlock(text='".'),
                            ]
                        )
                    ],
                )
            ],
            [PromptRequestStringInput(key="echo", value="Hello, world!")],
            [VellumVariable(id="901ec2d6-430c-4341-b963-ca689006f5cc", type="STRING", key="echo")],
            [
                CompiledChatMessagePromptBlock(
                    role="USER",
                    blocks=[
                        CompiledValuePromptBlock(content=StringVellumValue(value='Repeat back to me "Hello, world!".'))
                    ],
                ),
            ],
            id="chat-message",
        ),
    ],
)
def test_compile_prompt_blocks__happy(blocks, inputs, input_variables, expected):
    """Each case's compiled output must equal its expected compiled blocks."""
    compiled = compile_prompt_blocks(blocks=blocks, inputs=inputs, input_variables=input_variables)

    assert compiled == expected
|
@@ -0,0 +1,36 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import Annotated, Literal, Optional, Union
|
4
|
+
|
5
|
+
from vellum import ArrayVellumValue, ChatMessageRole, EphemeralPromptCacheConfig, VellumValue
|
6
|
+
from vellum.client.core import UniversalBaseModel
|
7
|
+
|
8
|
+
|
9
|
+
class BaseCompiledPromptBlock(UniversalBaseModel):
    """Fields shared by every compiled prompt block."""

    # Optional prompt-cache settings; compilation copies these from the source block when present.
    cache_config: Optional[EphemeralPromptCacheConfig] = None
|
11
|
+
|
12
|
+
|
13
|
+
class CompiledValuePromptBlock(BaseCompiledPromptBlock):
    """A compiled block that holds a single resolved Vellum value."""

    # Discriminator tag; always "VALUE" for this block kind.
    block_type: Literal["VALUE"] = "VALUE"
    # The resolved value (e.g. rendered Jinja text or a substituted variable).
    content: VellumValue
|
16
|
+
|
17
|
+
|
18
|
+
class CompiledChatMessagePromptBlock(BaseCompiledPromptBlock):
    """A compiled chat message made up of value blocks."""

    # Discriminator tag; always "CHAT_MESSAGE" for this block kind.
    block_type: Literal["CHAT_MESSAGE"] = "CHAT_MESSAGE"
    # Speaker of the message; defaults to the assistant when not supplied.
    role: ChatMessageRole = "ASSISTANT"
    # Whether the message is left open rather than terminated.
    unterminated: bool = False
    # The message body, already compiled to value blocks.
    blocks: list[CompiledValuePromptBlock] = []
    # Optional source identifier carried over from the originating message/block.
    source: Optional[str] = None
|
24
|
+
|
25
|
+
|
26
|
+
# Any compiled prompt block, tagged by its ``block_type`` field.
# NOTE(review): a bare string in Annotated metadata is likely not treated by pydantic as a union
# discriminator; confirm whether ``Field(discriminator="block_type")`` was intended.
CompiledPromptBlock = Annotated[
    Union[CompiledValuePromptBlock, CompiledChatMessagePromptBlock],
    "block_type",
]
|
33
|
+
|
34
|
+
# Rebuild these models now that every compiled block type is defined, so their string
# annotations (this module uses ``from __future__ import annotations``) can be resolved.
for _model in (ArrayVellumValue, CompiledValuePromptBlock):
    _model.model_rebuild()
del _model
|
vellum/utils/__init__.py
ADDED
File without changes
|
File without changes
|
@@ -0,0 +1,28 @@
|
|
1
|
+
import datetime
|
2
|
+
import itertools
|
3
|
+
import json
|
4
|
+
import random
|
5
|
+
import re
|
6
|
+
from typing import Any, Callable, Dict, Union
|
7
|
+
|
8
|
+
import dateutil
|
9
|
+
import pydash
|
10
|
+
import pytz
|
11
|
+
import yaml
|
12
|
+
|
13
|
+
from vellum.utils.templating.custom_filters import is_valid_json_string
|
14
|
+
|
15
|
+
# Modules exposed to sandboxed Jinja templates as global names.
DEFAULT_JINJA_GLOBALS: Dict[str, Any] = dict(
    datetime=datetime,
    dateutil=dateutil,
    itertools=itertools,
    json=json,
    pydash=pydash,
    pytz=pytz,
    random=random,
    re=re,
    yaml=yaml,
)

# Extra filters registered on the sandboxed Jinja environment.
DEFAULT_JINJA_CUSTOM_FILTERS: Dict[str, Callable[[Union[str, bytes]], bool]] = dict(
    is_valid_json_string=is_valid_json_string,
)
|
@@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, Optional, Union
|
|
3
3
|
|
4
4
|
from jinja2.sandbox import SandboxedEnvironment
|
5
5
|
|
6
|
-
from vellum.
|
6
|
+
from vellum.utils.templating.exceptions import JinjaTemplateError
|
7
7
|
from vellum.workflows.state.encoder import DefaultStateEncoder
|
8
8
|
|
9
9
|
|
@@ -344,8 +344,8 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
|
|
344
344
|
all_inputs = {}
|
345
345
|
for key, value in inputs.items():
|
346
346
|
path_parts = key.split(".")
|
347
|
-
|
348
|
-
inputs_key = reduce(lambda acc, part: acc[part], path_parts[1:],
|
347
|
+
node_attribute_descriptor = getattr(self.__class__, path_parts[0])
|
348
|
+
inputs_key = reduce(lambda acc, part: acc[part], path_parts[1:], node_attribute_descriptor)
|
349
349
|
all_inputs[inputs_key] = value
|
350
350
|
|
351
351
|
self._inputs = MappingProxyType(all_inputs)
|
@@ -355,7 +355,3 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
|
|
355
355
|
|
356
356
|
def __repr__(self) -> str:
|
357
357
|
return str(self.__class__)
|
358
|
-
|
359
|
-
|
360
|
-
class MyNode2(BaseNode):
|
361
|
-
pass
|
@@ -1,12 +1,14 @@
|
|
1
|
-
from typing import TYPE_CHECKING, Generic, Iterator, Optional, Set, Type, TypeVar
|
1
|
+
from typing import TYPE_CHECKING, ClassVar, Generic, Iterator, Optional, Set, Type, TypeVar, Union
|
2
2
|
|
3
3
|
from vellum.workflows.context import execution_context, get_parent_context
|
4
4
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
5
5
|
from vellum.workflows.exceptions import NodeException
|
6
|
-
from vellum.workflows.
|
6
|
+
from vellum.workflows.inputs.base import BaseInputs
|
7
|
+
from vellum.workflows.nodes.bases.base import BaseNode
|
7
8
|
from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
|
8
9
|
from vellum.workflows.state.base import BaseState
|
9
10
|
from vellum.workflows.state.context import WorkflowContext
|
11
|
+
from vellum.workflows.types.core import EntityInputsInterface
|
10
12
|
from vellum.workflows.types.generics import StateType, WorkflowInputsType
|
11
13
|
|
12
14
|
if TYPE_CHECKING:
|
@@ -15,7 +17,7 @@ if TYPE_CHECKING:
|
|
15
17
|
InnerStateType = TypeVar("InnerStateType", bound=BaseState)
|
16
18
|
|
17
19
|
|
18
|
-
class InlineSubworkflowNode(
|
20
|
+
class InlineSubworkflowNode(BaseNode[StateType], Generic[StateType, WorkflowInputsType, InnerStateType]):
|
19
21
|
"""
|
20
22
|
Used to execute a Subworkflow defined inline.
|
21
23
|
|
@@ -24,14 +26,13 @@ class InlineSubworkflowNode(BaseSubworkflowNode[StateType], Generic[StateType, W
|
|
24
26
|
"""
|
25
27
|
|
26
28
|
subworkflow: Type["BaseWorkflow[WorkflowInputsType, InnerStateType]"]
|
29
|
+
subworkflow_inputs: ClassVar[Union[EntityInputsInterface, BaseInputs]] = {}
|
27
30
|
|
28
31
|
def run(self) -> Iterator[BaseOutput]:
|
29
32
|
with execution_context(parent_context=get_parent_context() or self._context.parent_context):
|
30
33
|
subworkflow = self.subworkflow(
|
31
34
|
parent_state=self.state,
|
32
|
-
context=WorkflowContext(
|
33
|
-
_vellum_client=self._context._vellum_client,
|
34
|
-
),
|
35
|
+
context=WorkflowContext(vellum_client=self._context.vellum_client),
|
35
36
|
)
|
36
37
|
subworkflow_stream = subworkflow.stream(
|
37
38
|
inputs=self._compile_subworkflow_inputs(),
|
@@ -68,4 +69,9 @@ class InlineSubworkflowNode(BaseSubworkflowNode[StateType], Generic[StateType, W
|
|
68
69
|
|
69
70
|
def _compile_subworkflow_inputs(self) -> WorkflowInputsType:
    """Build the subworkflow's inputs from ``subworkflow_inputs``.

    A dict is expanded as keyword arguments into the subworkflow's inputs class. An instance of
    that class is returned as-is.

    Raises:
        ValueError: If ``subworkflow_inputs`` is neither a dict nor an instance of the inputs class.
    """
    inputs_class = self.subworkflow.get_inputs_class()
    provided = self.subworkflow_inputs

    if isinstance(provided, dict):
        return inputs_class(**provided)
    if isinstance(provided, inputs_class):
        return provided
    raise ValueError(f"Invalid subworkflow inputs type: {type(provided)}")
|
File without changes
|
@@ -0,0 +1,41 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
from vellum.workflows.inputs.base import BaseInputs
|
4
|
+
from vellum.workflows.nodes.bases.base import BaseNode
|
5
|
+
from vellum.workflows.nodes.core.inline_subworkflow_node.node import InlineSubworkflowNode
|
6
|
+
from vellum.workflows.outputs.base import BaseOutput
|
7
|
+
from vellum.workflows.state.base import BaseState
|
8
|
+
from vellum.workflows.workflows.base import BaseWorkflow
|
9
|
+
|
10
|
+
|
11
|
+
class Inputs(BaseInputs):
    """Inputs of the subworkflow under test."""

    # Value that the inner node echoes back as its output.
    foo: str
|
13
|
+
|
14
|
+
|
15
|
+
class MyInnerNode(BaseNode):
    """Node whose single output references the subworkflow's ``foo`` input."""

    class Outputs(BaseNode.Outputs):
        out = Inputs.foo
|
18
|
+
|
19
|
+
|
20
|
+
class MySubworkflow(BaseWorkflow[Inputs, BaseState]):
    """One-node workflow that surfaces ``MyInnerNode``'s output as its own."""

    graph = MyInnerNode

    class Outputs(BaseWorkflow.Outputs):
        out = MyInnerNode.Outputs.out
|
25
|
+
|
26
|
+
|
27
|
+
@pytest.mark.parametrize("inputs", [{"foo": "bar"}, Inputs(foo="bar")])
def test_inline_subworkflow_node__inputs(inputs):
    # GIVEN a node whose subworkflow inputs are given either as a plain dict or as an Inputs instance
    class MyNode(InlineSubworkflowNode):
        subworkflow = MySubworkflow
        subworkflow_inputs = inputs

    # WHEN the node is run to completion
    outputs = [*MyNode().run()]

    # THEN the subworkflow's output is yielded with the input value
    assert outputs == [BaseOutput(name="out", value="bar")]
|
@@ -108,7 +108,7 @@ class MapNode(BaseNode, Generic[StateType, MapNodeItemType]):
|
|
108
108
|
self._run_subworkflow(item=item, index=index)
|
109
109
|
|
110
110
|
def _run_subworkflow(self, *, item: MapNodeItemType, index: int) -> None:
|
111
|
-
context = WorkflowContext(
|
111
|
+
context = WorkflowContext(vellum_client=self._context.vellum_client)
|
112
112
|
subworkflow = self.subworkflow(parent_state=self.state, context=context)
|
113
113
|
events = subworkflow.stream(
|
114
114
|
inputs=self.SubworkflowInputs(index=index, item=item, all_items=self.items),
|
@@ -1,42 +1,17 @@
|
|
1
|
-
import datetime
|
2
|
-
import itertools
|
3
1
|
import json
|
4
|
-
import random
|
5
|
-
import re
|
6
2
|
from typing import Any, Callable, ClassVar, Dict, Generic, Mapping, Tuple, Type, TypeVar, Union, get_args
|
7
3
|
|
8
|
-
import
|
9
|
-
import
|
10
|
-
import
|
11
|
-
import yaml
|
12
|
-
|
4
|
+
from vellum.utils.templating.constants import DEFAULT_JINJA_CUSTOM_FILTERS, DEFAULT_JINJA_GLOBALS
|
5
|
+
from vellum.utils.templating.exceptions import JinjaTemplateError
|
6
|
+
from vellum.utils.templating.render import render_sandboxed_jinja_template
|
13
7
|
from vellum.workflows.errors import WorkflowErrorCode
|
14
8
|
from vellum.workflows.exceptions import NodeException
|
15
9
|
from vellum.workflows.nodes.bases import BaseNode
|
16
10
|
from vellum.workflows.nodes.bases.base import BaseNodeMeta
|
17
|
-
from vellum.workflows.
|
18
|
-
from vellum.workflows.nodes.core.templating_node.exceptions import JinjaTemplateError
|
19
|
-
from vellum.workflows.nodes.core.templating_node.render import render_sandboxed_jinja_template
|
20
|
-
from vellum.workflows.types.core import EntityInputsInterface
|
11
|
+
from vellum.workflows.types.core import EntityInputsInterface, Json
|
21
12
|
from vellum.workflows.types.generics import StateType
|
22
13
|
from vellum.workflows.types.utils import get_original_base
|
23
14
|
|
24
|
-
_DEFAULT_JINJA_GLOBALS: Dict[str, Any] = {
|
25
|
-
"datetime": datetime,
|
26
|
-
"dateutil": dateutil,
|
27
|
-
"itertools": itertools,
|
28
|
-
"json": json,
|
29
|
-
"pydash": pydash,
|
30
|
-
"pytz": pytz,
|
31
|
-
"random": random,
|
32
|
-
"re": re,
|
33
|
-
"yaml": yaml,
|
34
|
-
}
|
35
|
-
|
36
|
-
_DEFAULT_JINJA_CUSTOM_FILTERS: Dict[str, Callable[[Union[str, bytes]], bool]] = {
|
37
|
-
"is_valid_json_string": is_valid_json_string,
|
38
|
-
}
|
39
|
-
|
40
15
|
_OutputType = TypeVar("_OutputType")
|
41
16
|
|
42
17
|
|
@@ -78,8 +53,8 @@ class TemplatingNode(BaseNode[StateType], Generic[StateType, _OutputType], metac
|
|
78
53
|
# The inputs to render the template with.
|
79
54
|
inputs: ClassVar[EntityInputsInterface]
|
80
55
|
|
81
|
-
jinja_globals: Dict[str, Any] =
|
82
|
-
jinja_custom_filters: Mapping[str, Callable[[Union[str, bytes]], bool]] =
|
56
|
+
jinja_globals: Dict[str, Any] = DEFAULT_JINJA_GLOBALS
|
57
|
+
jinja_custom_filters: Mapping[str, Callable[[Union[str, bytes]], bool]] = DEFAULT_JINJA_CUSTOM_FILTERS
|
83
58
|
|
84
59
|
class Outputs(BaseNode.Outputs):
|
85
60
|
"""
|
@@ -113,6 +88,12 @@ class TemplatingNode(BaseNode[StateType], Generic[StateType, _OutputType], metac
|
|
113
88
|
if output_type is bool:
|
114
89
|
return bool(rendered_template)
|
115
90
|
|
91
|
+
if output_type is Json:
|
92
|
+
try:
|
93
|
+
return json.loads(rendered_template)
|
94
|
+
except json.JSONDecodeError:
|
95
|
+
raise ValueError("Invalid JSON format for rendered_template")
|
96
|
+
|
116
97
|
raise ValueError(f"Unsupported output type: {output_type}")
|
117
98
|
|
118
99
|
def run(self) -> Outputs:
|