vellum-ai 0.12.7__py3-none-any.whl → 0.12.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/evaluations/resources.py +2 -2
  3. vellum/prompts/__init__.py +0 -0
  4. vellum/prompts/blocks/__init__.py +0 -0
  5. vellum/prompts/blocks/compilation.py +190 -0
  6. vellum/prompts/blocks/exceptions.py +2 -0
  7. vellum/prompts/blocks/tests/__init__.py +0 -0
  8. vellum/prompts/blocks/tests/test_compilation.py +110 -0
  9. vellum/prompts/blocks/types.py +36 -0
  10. vellum/utils/__init__.py +0 -0
  11. vellum/utils/templating/__init__.py +0 -0
  12. vellum/utils/templating/constants.py +28 -0
  13. vellum/{workflows/nodes/core/templating_node → utils/templating}/render.py +1 -1
  14. vellum/workflows/nodes/bases/__init__.py +0 -2
  15. vellum/workflows/nodes/bases/base.py +2 -6
  16. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +13 -7
  17. vellum/workflows/nodes/core/inline_subworkflow_node/tests/__init__.py +0 -0
  18. vellum/workflows/nodes/core/inline_subworkflow_node/tests/test_node.py +41 -0
  19. vellum/workflows/nodes/core/map_node/node.py +1 -1
  20. vellum/workflows/nodes/core/templating_node/node.py +12 -31
  21. vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +66 -0
  22. vellum/workflows/nodes/core/try_node/node.py +1 -3
  23. vellum/workflows/nodes/core/try_node/tests/test_node.py +1 -1
  24. vellum/workflows/nodes/displayable/bases/api_node/node.py +2 -2
  25. vellum/workflows/nodes/displayable/bases/search_node.py +5 -2
  26. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +4 -2
  27. vellum/workflows/nodes/experimental/README.md +6 -0
  28. vellum/workflows/nodes/experimental/__init__.py +0 -0
  29. vellum/workflows/nodes/experimental/openai_chat_completion_node/__init__.py +5 -0
  30. vellum/workflows/nodes/experimental/openai_chat_completion_node/node.py +260 -0
  31. vellum/workflows/sandbox.py +7 -5
  32. vellum/workflows/state/context.py +5 -4
  33. vellum/workflows/tests/test_sandbox.py +2 -2
  34. vellum/workflows/utils/tests/test_vellum_variables.py +3 -0
  35. vellum/workflows/utils/vellum_variables.py +5 -4
  36. vellum/workflows/workflows/base.py +5 -2
  37. {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/METADATA +2 -1
  38. {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/RECORD +48 -33
  39. vellum_cli/tests/test_push.py +1 -1
  40. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +6 -1
  41. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +207 -0
  42. vellum/workflows/nodes/bases/base_subworkflow_node/__init__.py +0 -5
  43. vellum/workflows/nodes/bases/base_subworkflow_node/node.py +0 -10
  44. /vellum/{workflows/nodes/core/templating_node → utils/templating}/custom_filters.py +0 -0
  45. /vellum/{workflows/nodes/core/templating_node → utils/templating}/exceptions.py +0 -0
  46. /vellum/{evaluations/utils → utils}/typing.py +0 -0
  47. /vellum/{evaluations/utils → utils}/uuid.py +0 -0
  48. {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/LICENSE +0 -0
  49. {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/WHEEL +0 -0
  50. {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
18
18
  headers: typing.Dict[str, str] = {
19
19
  "X-Fern-Language": "Python",
20
20
  "X-Fern-SDK-Name": "vellum-ai",
21
- "X-Fern-SDK-Version": "0.12.7",
21
+ "X-Fern-SDK-Version": "0.12.9",
22
22
  }
23
23
  headers["X_API_KEY"] = self.api_key
24
24
  return headers
@@ -12,8 +12,6 @@ from vellum.evaluations.constants import DEFAULT_MAX_POLLING_DURATION_MS, DEFAUL
12
12
  from vellum.evaluations.exceptions import TestSuiteRunResultsException
13
13
  from vellum.evaluations.utils.env import get_api_key
14
14
  from vellum.evaluations.utils.paginator import PaginatedResults, get_all_results
15
- from vellum.evaluations.utils.typing import cast_not_optional
16
- from vellum.evaluations.utils.uuid import is_valid_uuid
17
15
  from vellum.types import (
18
16
  ExternalTestCaseExecutionRequest,
19
17
  NamedTestCaseVariableValueRequest,
@@ -24,6 +22,8 @@ from vellum.types import (
24
22
  TestSuiteRunMetricOutput,
25
23
  TestSuiteRunState,
26
24
  )
25
+ from vellum.utils.typing import cast_not_optional
26
+ from vellum.utils.uuid import is_valid_uuid
27
27
 
28
28
  logger = logging.getLogger(__name__)
29
29
 
File without changes
File without changes
@@ -0,0 +1,190 @@
1
+ import json
2
+ from typing import Optional, cast
3
+
4
+ from vellum import (
5
+ ChatMessage,
6
+ JsonVellumValue,
7
+ PromptBlock,
8
+ PromptRequestInput,
9
+ RichTextPromptBlock,
10
+ StringVellumValue,
11
+ VellumVariable,
12
+ )
13
+ from vellum.prompts.blocks.exceptions import PromptCompilationError
14
+ from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock, CompiledPromptBlock, CompiledValuePromptBlock
15
+ from vellum.utils.templating.constants import DEFAULT_JINJA_CUSTOM_FILTERS
16
+ from vellum.utils.templating.render import render_sandboxed_jinja_template
17
+ from vellum.utils.typing import cast_not_optional
18
+
19
+
20
def compile_prompt_blocks(
    blocks: list[PromptBlock],
    inputs: list[PromptRequestInput],
    input_variables: list[VellumVariable],
) -> list[CompiledPromptBlock]:
    """Compiles a list of Prompt Blocks, performing all variable substitutions and Jinja templating needed.

    Args:
        blocks: The Prompt Blocks to compile. Blocks whose ``state`` is ``"DISABLED"`` are skipped.
        inputs: The input values available for substitution. They are matched by ``key`` against
            VARIABLE blocks and are exposed to Jinja templates as ``{key: value}``.
        input_variables: The variable definitions for the prompt. This function only forwards them
            to its recursive calls for CHAT_MESSAGE blocks.

    Returns:
        The compiled blocks, in the same relative order as ``blocks``. A CHAT_HISTORY variable
        expands into one compiled chat message block per message.

    Raises:
        PromptCompilationError: If a VARIABLE block references a key that is not in ``inputs``, if
            the matched input is not of type CHAT_HISTORY, STRING or JSON, if a FUNCTION_DEFINITION
            block is encountered, or if the block type is unknown.
    """

    # None values for CHAT_HISTORY / STRING inputs are replaced by empty defaults up front.
    sanitized_inputs = _sanitize_inputs(inputs)

    compiled_blocks: list[CompiledPromptBlock] = []
    for block in blocks:
        if block.state == "DISABLED":
            continue

        if block.block_type == "CHAT_MESSAGE":
            chat_role = cast_not_optional(block.chat_role)
            inner_blocks = cast_not_optional(block.blocks)
            unterminated = block.chat_message_unterminated or False

            inner_prompt_blocks = compile_prompt_blocks(
                inner_blocks,
                sanitized_inputs,
                input_variables,
            )
            # A chat message whose children all compiled away produces no output block.
            if not inner_prompt_blocks:
                continue

            compiled_blocks.append(
                CompiledChatMessagePromptBlock(
                    role=chat_role,
                    unterminated=unterminated,
                    source=block.chat_source,
                    # Only VALUE children are kept; compiled chat messages nested inside a
                    # chat message are dropped.
                    blocks=[inner for inner in inner_prompt_blocks if inner.block_type == "VALUE"],
                    cache_config=block.cache_config,
                )
            )

        elif block.block_type == "JINJA":
            if block.template is None:
                continue

            # NOTE(review): the custom filters mapping is also passed as ``jinja_globals``.
            # This looks like a copy/paste slip, but it is left unchanged because switching to a
            # globals mapping would widen what templates can access - confirm the intent.
            rendered_template = render_sandboxed_jinja_template(
                template=block.template,
                input_values={input_.key: input_.value for input_ in sanitized_inputs},
                jinja_custom_filters=DEFAULT_JINJA_CUSTOM_FILTERS,
                jinja_globals=DEFAULT_JINJA_CUSTOM_FILTERS,
            )
            jinja_content = StringVellumValue(value=rendered_template)

            compiled_blocks.append(
                CompiledValuePromptBlock(
                    content=jinja_content,
                    cache_config=block.cache_config,
                )
            )

        elif block.block_type == "VARIABLE":
            compiled_input: Optional[PromptRequestInput] = next(
                (input_ for input_ in sanitized_inputs if input_.key == str(block.input_variable)), None
            )
            if compiled_input is None:
                raise PromptCompilationError(f"Input variable '{block.input_variable}' not found")

            if compiled_input.type == "CHAT_HISTORY":
                history = cast(list[ChatMessage], compiled_input.value)
                compiled_blocks.extend(_compile_chat_messages_as_prompt_blocks(history))
            elif compiled_input.type == "STRING":
                compiled_blocks.append(
                    CompiledValuePromptBlock(
                        content=StringVellumValue(value=compiled_input.value),
                        cache_config=block.cache_config,
                    )
                )
            # Compare the input's ``type``; comparing the input object itself to "JSON" is never
            # true and made every JSON input fall through to the error branch below.
            elif compiled_input.type == "JSON":
                compiled_blocks.append(
                    CompiledValuePromptBlock(
                        content=JsonVellumValue(value=compiled_input.value),
                        cache_config=block.cache_config,
                    )
                )
            else:
                raise PromptCompilationError(f"Invalid input type for variable block: {compiled_input.type}")

        elif block.block_type == "RICH_TEXT":
            value_block = _compile_rich_text_block_as_value_block(block=block, inputs=sanitized_inputs)
            compiled_blocks.append(value_block)

        elif block.block_type == "FUNCTION_DEFINITION":
            raise PromptCompilationError("Function definitions shouldn't go through compilation process")
        else:
            raise PromptCompilationError(f"Unknown block_type: {block.block_type}")

    return compiled_blocks
119
+
120
+
121
def _compile_chat_messages_as_prompt_blocks(chat_messages: list[ChatMessage]) -> list[CompiledChatMessagePromptBlock]:
    """Convert chat messages into compiled chat message blocks.

    Messages without content are skipped. ARRAY content yields one value block per array item;
    any other content yields a single value block wrapping the content itself.
    """
    compiled: list[CompiledChatMessagePromptBlock] = []

    for message in chat_messages:
        content = message.content
        if content is None:
            continue

        # ARRAY content is flattened item by item; everything else is wrapped as-is.
        parts = content.value if content.type == "ARRAY" else [content]
        value_blocks = [CompiledValuePromptBlock(content=part) for part in parts]

        compiled_message = CompiledChatMessagePromptBlock(
            role=message.role,
            unterminated=False,
            blocks=value_blocks,
            source=message.source,
        )
        compiled.append(compiled_message)

    return compiled
152
+
153
+
154
def _compile_rich_text_block_as_value_block(
    block: RichTextPromptBlock,
    inputs: list[PromptRequestInput],
) -> CompiledValuePromptBlock:
    """Flatten a RICH_TEXT block into a single string value block.

    PLAIN_TEXT children contribute their text verbatim. VARIABLE children contribute the matching
    input's value: STRING inputs via ``str()`` and JSON inputs via ``json.dumps(..., indent=4)``.

    Raises:
        PromptCompilationError: If a VARIABLE child names an input that is not present, if that
            input is neither STRING nor JSON, or if a child has any other block type.
    """
    rendered: str = ""

    for child in block.blocks:
        kind = child.block_type

        if kind == "PLAIN_TEXT":
            rendered += child.text
            continue

        if kind != "VARIABLE":
            raise PromptCompilationError(f"Invalid child block_type for RICH_TEXT: {child.block_type}")

        # The first input whose key matches wins.
        wanted_key = str(child.input_variable)
        match = next((candidate for candidate in inputs if candidate.key == wanted_key), None)
        if match is None:
            raise PromptCompilationError(f"Input variable '{child.input_variable}' not found")

        if match.type == "STRING":
            rendered += str(match.value)
        elif match.type == "JSON":
            rendered += json.dumps(match.value, indent=4)
        else:
            raise PromptCompilationError(
                f"Input variable '{child.input_variable}' must be of type STRING or JSON"
            )

    content = StringVellumValue(value=rendered)
    return CompiledValuePromptBlock(content=content, cache_config=block.cache_config)
178
+
179
+
180
def _sanitize_inputs(inputs: list[PromptRequestInput]) -> list[PromptRequestInput]:
    """Return the inputs with missing CHAT_HISTORY / STRING values replaced by empty defaults.

    A CHAT_HISTORY input whose value is None is copied with an empty list, and a STRING input
    whose value is None is copied with an empty string. Every other input is passed through
    unchanged (the same object, not a copy).
    """

    def _with_default(input_: PromptRequestInput) -> PromptRequestInput:
        # Only a missing value is ever substituted.
        if input_.value is not None:
            return input_
        if input_.type == "CHAT_HISTORY":
            return input_.model_copy(update={"value": cast(list[ChatMessage], [])})
        if input_.type == "STRING":
            return input_.model_copy(update={"value": ""})
        return input_

    return [_with_default(input_) for input_ in inputs]
@@ -0,0 +1,2 @@
1
class PromptCompilationError(Exception):
    """Raised when a Prompt Block cannot be compiled."""
File without changes
@@ -0,0 +1,110 @@
1
+ import pytest
2
+
3
+ from vellum import (
4
+ ChatMessagePromptBlock,
5
+ JinjaPromptBlock,
6
+ PlainTextPromptBlock,
7
+ PromptRequestStringInput,
8
+ RichTextPromptBlock,
9
+ StringVellumValue,
10
+ VariablePromptBlock,
11
+ VellumVariable,
12
+ )
13
+ from vellum.prompts.blocks.compilation import compile_prompt_blocks
14
+ from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock, CompiledValuePromptBlock
15
+
16
+
17
@pytest.mark.parametrize(
    ["blocks", "inputs", "input_variables", "expected"],
    [
        pytest.param([], [], [], [], id="empty"),
        # Jinja
        pytest.param(
            [JinjaPromptBlock(template="Hello, world!")],
            [],
            [],
            [CompiledValuePromptBlock(content=StringVellumValue(value="Hello, world!"))],
            id="jinja-no-variables",
        ),
        pytest.param(
            [JinjaPromptBlock(template="Repeat back to me {{ echo }}")],
            [PromptRequestStringInput(key="echo", value="Hello, world!")],
            [VellumVariable(id="1", type="STRING", key="echo")],
            [CompiledValuePromptBlock(content=StringVellumValue(value="Repeat back to me Hello, world!"))],
            id="jinja-with-variables",
        ),
        # Rich Text
        pytest.param(
            [RichTextPromptBlock(blocks=[PlainTextPromptBlock(text="Hello, world!")])],
            [],
            [],
            [CompiledValuePromptBlock(content=StringVellumValue(value="Hello, world!"))],
            id="rich-text-no-variables",
        ),
        pytest.param(
            [
                RichTextPromptBlock(
                    blocks=[
                        PlainTextPromptBlock(text='Repeat back to me "'),
                        VariablePromptBlock(input_variable="echo"),
                        PlainTextPromptBlock(text='".'),
                    ]
                )
            ],
            [PromptRequestStringInput(key="echo", value="Hello, world!")],
            [VellumVariable(id="901ec2d6-430c-4341-b963-ca689006f5cc", type="STRING", key="echo")],
            [CompiledValuePromptBlock(content=StringVellumValue(value='Repeat back to me "Hello, world!".'))],
            id="rich-text-with-variables",
        ),
        # Chat Message
        pytest.param(
            [
                ChatMessagePromptBlock(
                    chat_role="USER",
                    blocks=[
                        RichTextPromptBlock(
                            blocks=[
                                PlainTextPromptBlock(text='Repeat back to me "'),
                                VariablePromptBlock(input_variable="echo"),
                                PlainTextPromptBlock(text='".'),
                            ]
                        )
                    ],
                )
            ],
            [PromptRequestStringInput(key="echo", value="Hello, world!")],
            [VellumVariable(id="901ec2d6-430c-4341-b963-ca689006f5cc", type="STRING", key="echo")],
            [
                CompiledChatMessagePromptBlock(
                    role="USER",
                    blocks=[
                        CompiledValuePromptBlock(content=StringVellumValue(value='Repeat back to me "Hello, world!".'))
                    ],
                ),
            ],
            id="chat-message",
        ),
    ],
)
def test_compile_prompt_blocks__happy(blocks, inputs, input_variables, expected):
    """Each case compiles `blocks` against `inputs` and must equal `expected` exactly."""
    compiled = compile_prompt_blocks(blocks=blocks, inputs=inputs, input_variables=input_variables)

    assert compiled == expected
@@ -0,0 +1,36 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Annotated, Literal, Optional, Union
4
+
5
+ from vellum import ArrayVellumValue, ChatMessageRole, EphemeralPromptCacheConfig, VellumValue
6
+ from vellum.client.core import UniversalBaseModel
7
+
8
+
9
# Common base for the compiled prompt block models declared in this module.
class BaseCompiledPromptBlock(UniversalBaseModel):
    # Optional cache configuration carried on the compiled block; None when not set.
    cache_config: Optional[EphemeralPromptCacheConfig] = None
11
+
12
+
13
# A compiled block that holds a single value.
class CompiledValuePromptBlock(BaseCompiledPromptBlock):
    # Fixed tag that distinguishes this model from the other compiled block types.
    block_type: Literal["VALUE"] = "VALUE"
    # The value held by this block (required).
    content: VellumValue
16
+
17
+
18
# A compiled chat message: a role plus the value blocks that make up its content.
class CompiledChatMessagePromptBlock(BaseCompiledPromptBlock):
    # Fixed tag that distinguishes this model from the other compiled block types.
    block_type: Literal["CHAT_MESSAGE"] = "CHAT_MESSAGE"
    # Role of the message; defaults to "ASSISTANT" when not given.
    role: ChatMessageRole = "ASSISTANT"
    # Defaults to False.
    unterminated: bool = False
    # Value blocks contained in the message; defaults to an empty list.
    blocks: list[CompiledValuePromptBlock] = []
    # Optional source of the message; None when not set.
    source: Optional[str] = None
24
+
25
+
26
# Union of every compiled block model.
# NOTE(review): the "block_type" string is plain Annotated metadata; presumably it names the
# field that tells the union members apart - confirm how it is consumed.
CompiledPromptBlock = Annotated[
    Union[
        CompiledValuePromptBlock,
        CompiledChatMessagePromptBlock,
    ],
    "block_type",
]
33
+
34
+ ArrayVellumValue.model_rebuild()
35
+
36
+ CompiledValuePromptBlock.model_rebuild()
File without changes
File without changes
@@ -0,0 +1,28 @@
1
+ import datetime
2
+ import itertools
3
+ import json
4
+ import random
5
+ import re
6
+ from typing import Any, Callable, Dict, Union
7
+
8
+ import dateutil
9
+ import pydash
10
+ import pytz
11
+ import yaml
12
+
13
+ from vellum.utils.templating.custom_filters import is_valid_json_string
14
+
15
# Default mapping of global names to the modules they refer to, for Jinja rendering.
# NOTE(review): this module imports only the top-level `dateutil` package, so whether
# `dateutil.parser` is reachable through the "dateutil" entry depends on something else having
# imported that submodule - confirm templates that rely on it still work.
DEFAULT_JINJA_GLOBALS: Dict[str, Any] = {
    "datetime": datetime,
    "dateutil": dateutil,
    "itertools": itertools,
    "json": json,
    "pydash": pydash,
    "pytz": pytz,
    "random": random,
    "re": re,
    "yaml": yaml,
}
# Default mapping of custom Jinja filter names to their implementations.
DEFAULT_JINJA_CUSTOM_FILTERS: Dict[str, Callable[[Union[str, bytes]], bool]] = {
    "is_valid_json_string": is_valid_json_string,
}
@@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, Optional, Union
3
3
 
4
4
  from jinja2.sandbox import SandboxedEnvironment
5
5
 
6
- from vellum.workflows.nodes.core.templating_node.exceptions import JinjaTemplateError
6
+ from vellum.utils.templating.exceptions import JinjaTemplateError
7
7
  from vellum.workflows.state.encoder import DefaultStateEncoder
8
8
 
9
9
 
@@ -1,7 +1,5 @@
1
1
  from .base import BaseNode
2
- from .base_subworkflow_node import BaseSubworkflowNode
3
2
 
4
3
  __all__ = [
5
4
  "BaseNode",
6
- "BaseSubworkflowNode",
7
5
  ]
@@ -344,8 +344,8 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
344
344
  all_inputs = {}
345
345
  for key, value in inputs.items():
346
346
  path_parts = key.split(".")
347
- node_attribute_discriptor = getattr(self.__class__, path_parts[0])
348
- inputs_key = reduce(lambda acc, part: acc[part], path_parts[1:], node_attribute_discriptor)
347
+ node_attribute_descriptor = getattr(self.__class__, path_parts[0])
348
+ inputs_key = reduce(lambda acc, part: acc[part], path_parts[1:], node_attribute_descriptor)
349
349
  all_inputs[inputs_key] = value
350
350
 
351
351
  self._inputs = MappingProxyType(all_inputs)
@@ -355,7 +355,3 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
355
355
 
356
356
  def __repr__(self) -> str:
357
357
  return str(self.__class__)
358
-
359
-
360
- class MyNode2(BaseNode):
361
- pass
@@ -1,12 +1,14 @@
1
- from typing import TYPE_CHECKING, Generic, Iterator, Optional, Set, Type, TypeVar
1
+ from typing import TYPE_CHECKING, ClassVar, Generic, Iterator, Optional, Set, Type, TypeVar, Union
2
2
 
3
3
  from vellum.workflows.context import execution_context, get_parent_context
4
4
  from vellum.workflows.errors.types import WorkflowErrorCode
5
5
  from vellum.workflows.exceptions import NodeException
6
- from vellum.workflows.nodes.bases.base_subworkflow_node import BaseSubworkflowNode
6
+ from vellum.workflows.inputs.base import BaseInputs
7
+ from vellum.workflows.nodes.bases.base import BaseNode
7
8
  from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
8
9
  from vellum.workflows.state.base import BaseState
9
10
  from vellum.workflows.state.context import WorkflowContext
11
+ from vellum.workflows.types.core import EntityInputsInterface
10
12
  from vellum.workflows.types.generics import StateType, WorkflowInputsType
11
13
 
12
14
  if TYPE_CHECKING:
@@ -15,7 +17,7 @@ if TYPE_CHECKING:
15
17
  InnerStateType = TypeVar("InnerStateType", bound=BaseState)
16
18
 
17
19
 
18
- class InlineSubworkflowNode(BaseSubworkflowNode[StateType], Generic[StateType, WorkflowInputsType, InnerStateType]):
20
+ class InlineSubworkflowNode(BaseNode[StateType], Generic[StateType, WorkflowInputsType, InnerStateType]):
19
21
  """
20
22
  Used to execute a Subworkflow defined inline.
21
23
 
@@ -24,14 +26,13 @@ class InlineSubworkflowNode(BaseSubworkflowNode[StateType], Generic[StateType, W
24
26
  """
25
27
 
26
28
  subworkflow: Type["BaseWorkflow[WorkflowInputsType, InnerStateType]"]
29
+ subworkflow_inputs: ClassVar[Union[EntityInputsInterface, BaseInputs]] = {}
27
30
 
28
31
  def run(self) -> Iterator[BaseOutput]:
29
32
  with execution_context(parent_context=get_parent_context() or self._context.parent_context):
30
33
  subworkflow = self.subworkflow(
31
34
  parent_state=self.state,
32
- context=WorkflowContext(
33
- _vellum_client=self._context._vellum_client,
34
- ),
35
+ context=WorkflowContext(vellum_client=self._context.vellum_client),
35
36
  )
36
37
  subworkflow_stream = subworkflow.stream(
37
38
  inputs=self._compile_subworkflow_inputs(),
@@ -68,4 +69,9 @@ class InlineSubworkflowNode(BaseSubworkflowNode[StateType], Generic[StateType, W
68
69
 
69
70
  def _compile_subworkflow_inputs(self) -> WorkflowInputsType:
70
71
  inputs_class = self.subworkflow.get_inputs_class()
71
- return inputs_class(**self.subworkflow_inputs)
72
+ if isinstance(self.subworkflow_inputs, dict):
73
+ return inputs_class(**self.subworkflow_inputs)
74
+ elif isinstance(self.subworkflow_inputs, inputs_class):
75
+ return self.subworkflow_inputs
76
+ else:
77
+ raise ValueError(f"Invalid subworkflow inputs type: {type(self.subworkflow_inputs)}")
@@ -0,0 +1,41 @@
1
+ import pytest
2
+
3
+ from vellum.workflows.inputs.base import BaseInputs
4
+ from vellum.workflows.nodes.bases.base import BaseNode
5
+ from vellum.workflows.nodes.core.inline_subworkflow_node.node import InlineSubworkflowNode
6
+ from vellum.workflows.outputs.base import BaseOutput
7
+ from vellum.workflows.state.base import BaseState
8
+ from vellum.workflows.workflows.base import BaseWorkflow
9
+
10
+
11
# Inputs for the subworkflow under test: a single required string.
class Inputs(BaseInputs):
    foo: str
13
+
14
+
15
# Node whose only output, `out`, is bound to the `foo` workflow input.
class MyInnerNode(BaseNode):
    class Outputs(BaseNode.Outputs):
        out = Inputs.foo
18
+
19
+
20
# Workflow whose graph is just MyInnerNode and whose `out` output is bound to that node's `out`.
class MySubworkflow(BaseWorkflow[Inputs, BaseState]):
    graph = MyInnerNode

    class Outputs(BaseWorkflow.Outputs):
        out = MyInnerNode.Outputs.out
25
+
26
+
27
@pytest.mark.parametrize("inputs", [{"foo": "bar"}, Inputs(foo="bar")])
def test_inline_subworkflow_node__inputs(inputs):
    """Subworkflow inputs given as a dict or as an Inputs instance yield the same output."""

    # GIVEN a node setup with subworkflow inputs
    class MyNode(InlineSubworkflowNode):
        subworkflow = MySubworkflow
        subworkflow_inputs = inputs

    expected_events = [
        BaseOutput(name="out", value="bar"),
    ]

    # WHEN the node is run
    emitted_events = [*MyNode().run()]

    # THEN the output is as expected
    assert emitted_events == expected_events
@@ -108,7 +108,7 @@ class MapNode(BaseNode, Generic[StateType, MapNodeItemType]):
108
108
  self._run_subworkflow(item=item, index=index)
109
109
 
110
110
  def _run_subworkflow(self, *, item: MapNodeItemType, index: int) -> None:
111
- context = WorkflowContext(_vellum_client=self._context._vellum_client)
111
+ context = WorkflowContext(vellum_client=self._context.vellum_client)
112
112
  subworkflow = self.subworkflow(parent_state=self.state, context=context)
113
113
  events = subworkflow.stream(
114
114
  inputs=self.SubworkflowInputs(index=index, item=item, all_items=self.items),
@@ -1,42 +1,17 @@
1
- import datetime
2
- import itertools
3
1
  import json
4
- import random
5
- import re
6
2
  from typing import Any, Callable, ClassVar, Dict, Generic, Mapping, Tuple, Type, TypeVar, Union, get_args
7
3
 
8
- import dateutil.parser
9
- import pydash
10
- import pytz
11
- import yaml
12
-
4
+ from vellum.utils.templating.constants import DEFAULT_JINJA_CUSTOM_FILTERS, DEFAULT_JINJA_GLOBALS
5
+ from vellum.utils.templating.exceptions import JinjaTemplateError
6
+ from vellum.utils.templating.render import render_sandboxed_jinja_template
13
7
  from vellum.workflows.errors import WorkflowErrorCode
14
8
  from vellum.workflows.exceptions import NodeException
15
9
  from vellum.workflows.nodes.bases import BaseNode
16
10
  from vellum.workflows.nodes.bases.base import BaseNodeMeta
17
- from vellum.workflows.nodes.core.templating_node.custom_filters import is_valid_json_string
18
- from vellum.workflows.nodes.core.templating_node.exceptions import JinjaTemplateError
19
- from vellum.workflows.nodes.core.templating_node.render import render_sandboxed_jinja_template
20
- from vellum.workflows.types.core import EntityInputsInterface
11
+ from vellum.workflows.types.core import EntityInputsInterface, Json
21
12
  from vellum.workflows.types.generics import StateType
22
13
  from vellum.workflows.types.utils import get_original_base
23
14
 
24
- _DEFAULT_JINJA_GLOBALS: Dict[str, Any] = {
25
- "datetime": datetime,
26
- "dateutil": dateutil,
27
- "itertools": itertools,
28
- "json": json,
29
- "pydash": pydash,
30
- "pytz": pytz,
31
- "random": random,
32
- "re": re,
33
- "yaml": yaml,
34
- }
35
-
36
- _DEFAULT_JINJA_CUSTOM_FILTERS: Dict[str, Callable[[Union[str, bytes]], bool]] = {
37
- "is_valid_json_string": is_valid_json_string,
38
- }
39
-
40
15
  _OutputType = TypeVar("_OutputType")
41
16
 
42
17
 
@@ -78,8 +53,8 @@ class TemplatingNode(BaseNode[StateType], Generic[StateType, _OutputType], metac
78
53
  # The inputs to render the template with.
79
54
  inputs: ClassVar[EntityInputsInterface]
80
55
 
81
- jinja_globals: Dict[str, Any] = _DEFAULT_JINJA_GLOBALS
82
- jinja_custom_filters: Mapping[str, Callable[[Union[str, bytes]], bool]] = _DEFAULT_JINJA_CUSTOM_FILTERS
56
+ jinja_globals: Dict[str, Any] = DEFAULT_JINJA_GLOBALS
57
+ jinja_custom_filters: Mapping[str, Callable[[Union[str, bytes]], bool]] = DEFAULT_JINJA_CUSTOM_FILTERS
83
58
 
84
59
  class Outputs(BaseNode.Outputs):
85
60
  """
@@ -113,6 +88,12 @@ class TemplatingNode(BaseNode[StateType], Generic[StateType, _OutputType], metac
113
88
  if output_type is bool:
114
89
  return bool(rendered_template)
115
90
 
91
+ if output_type is Json:
92
+ try:
93
+ return json.loads(rendered_template)
94
+ except json.JSONDecodeError:
95
+ raise ValueError("Invalid JSON format for rendered_template")
96
+
116
97
  raise ValueError(f"Unsupported output type: {output_type}")
117
98
 
118
99
  def run(self) -> Outputs: