vellum-ai 0.12.7__py3-none-any.whl → 0.12.9__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (50)
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/evaluations/resources.py +2 -2
  3. vellum/prompts/__init__.py +0 -0
  4. vellum/prompts/blocks/__init__.py +0 -0
  5. vellum/prompts/blocks/compilation.py +190 -0
  6. vellum/prompts/blocks/exceptions.py +2 -0
  7. vellum/prompts/blocks/tests/__init__.py +0 -0
  8. vellum/prompts/blocks/tests/test_compilation.py +110 -0
  9. vellum/prompts/blocks/types.py +36 -0
  10. vellum/utils/__init__.py +0 -0
  11. vellum/utils/templating/__init__.py +0 -0
  12. vellum/utils/templating/constants.py +28 -0
  13. vellum/{workflows/nodes/core/templating_node → utils/templating}/render.py +1 -1
  14. vellum/workflows/nodes/bases/__init__.py +0 -2
  15. vellum/workflows/nodes/bases/base.py +2 -6
  16. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +13 -7
  17. vellum/workflows/nodes/core/inline_subworkflow_node/tests/__init__.py +0 -0
  18. vellum/workflows/nodes/core/inline_subworkflow_node/tests/test_node.py +41 -0
  19. vellum/workflows/nodes/core/map_node/node.py +1 -1
  20. vellum/workflows/nodes/core/templating_node/node.py +12 -31
  21. vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +66 -0
  22. vellum/workflows/nodes/core/try_node/node.py +1 -3
  23. vellum/workflows/nodes/core/try_node/tests/test_node.py +1 -1
  24. vellum/workflows/nodes/displayable/bases/api_node/node.py +2 -2
  25. vellum/workflows/nodes/displayable/bases/search_node.py +5 -2
  26. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +4 -2
  27. vellum/workflows/nodes/experimental/README.md +6 -0
  28. vellum/workflows/nodes/experimental/__init__.py +0 -0
  29. vellum/workflows/nodes/experimental/openai_chat_completion_node/__init__.py +5 -0
  30. vellum/workflows/nodes/experimental/openai_chat_completion_node/node.py +260 -0
  31. vellum/workflows/sandbox.py +7 -5
  32. vellum/workflows/state/context.py +5 -4
  33. vellum/workflows/tests/test_sandbox.py +2 -2
  34. vellum/workflows/utils/tests/test_vellum_variables.py +3 -0
  35. vellum/workflows/utils/vellum_variables.py +5 -4
  36. vellum/workflows/workflows/base.py +5 -2
  37. {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/METADATA +2 -1
  38. {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/RECORD +48 -33
  39. vellum_cli/tests/test_push.py +1 -1
  40. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +6 -1
  41. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +207 -0
  42. vellum/workflows/nodes/bases/base_subworkflow_node/__init__.py +0 -5
  43. vellum/workflows/nodes/bases/base_subworkflow_node/node.py +0 -10
  44. /vellum/{workflows/nodes/core/templating_node → utils/templating}/custom_filters.py +0 -0
  45. /vellum/{workflows/nodes/core/templating_node → utils/templating}/exceptions.py +0 -0
  46. /vellum/{evaluations/utils → utils}/typing.py +0 -0
  47. /vellum/{evaluations/utils → utils}/uuid.py +0 -0
  48. {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/LICENSE +0 -0
  49. {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/WHEEL +0 -0
  50. {vellum_ai-0.12.7.dist-info → vellum_ai-0.12.9.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
18
18
  headers: typing.Dict[str, str] = {
19
19
  "X-Fern-Language": "Python",
20
20
  "X-Fern-SDK-Name": "vellum-ai",
21
- "X-Fern-SDK-Version": "0.12.7",
21
+ "X-Fern-SDK-Version": "0.12.9",
22
22
  }
23
23
  headers["X_API_KEY"] = self.api_key
24
24
  return headers
@@ -12,8 +12,6 @@ from vellum.evaluations.constants import DEFAULT_MAX_POLLING_DURATION_MS, DEFAUL
12
12
  from vellum.evaluations.exceptions import TestSuiteRunResultsException
13
13
  from vellum.evaluations.utils.env import get_api_key
14
14
  from vellum.evaluations.utils.paginator import PaginatedResults, get_all_results
15
- from vellum.evaluations.utils.typing import cast_not_optional
16
- from vellum.evaluations.utils.uuid import is_valid_uuid
17
15
  from vellum.types import (
18
16
  ExternalTestCaseExecutionRequest,
19
17
  NamedTestCaseVariableValueRequest,
@@ -24,6 +22,8 @@ from vellum.types import (
24
22
  TestSuiteRunMetricOutput,
25
23
  TestSuiteRunState,
26
24
  )
25
+ from vellum.utils.typing import cast_not_optional
26
+ from vellum.utils.uuid import is_valid_uuid
27
27
 
28
28
  logger = logging.getLogger(__name__)
29
29
 
File without changes
File without changes
@@ -0,0 +1,190 @@
1
+ import json
2
+ from typing import Optional, cast
3
+
4
+ from vellum import (
5
+ ChatMessage,
6
+ JsonVellumValue,
7
+ PromptBlock,
8
+ PromptRequestInput,
9
+ RichTextPromptBlock,
10
+ StringVellumValue,
11
+ VellumVariable,
12
+ )
13
+ from vellum.prompts.blocks.exceptions import PromptCompilationError
14
+ from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock, CompiledPromptBlock, CompiledValuePromptBlock
15
+ from vellum.utils.templating.constants import DEFAULT_JINJA_CUSTOM_FILTERS
16
+ from vellum.utils.templating.render import render_sandboxed_jinja_template
17
+ from vellum.utils.typing import cast_not_optional
18
+
19
+
20
def compile_prompt_blocks(
    blocks: list[PromptBlock],
    inputs: list[PromptRequestInput],
    input_variables: list[VellumVariable],
) -> list[CompiledPromptBlock]:
    """Compile a list of Prompt Blocks, performing all variable substitutions and Jinja templating needed.

    Args:
        blocks: Blocks to compile, in order. Blocks whose ``state`` is ``"DISABLED"`` are skipped.
        inputs: Input values. ``None`` STRING / CHAT_HISTORY values are first normalized to
            ``""`` / ``[]``. Variables are matched against each input's ``key``.
        input_variables: Declared input variables; forwarded to the recursive compilation of
            CHAT_MESSAGE children but not otherwise read here.

    Returns:
        The compiled blocks. A CHAT_MESSAGE block is omitted when none of its children compile
        to anything; a VARIABLE block bound to a CHAT_HISTORY input expands into one compiled
        chat message per history entry.

    Raises:
        PromptCompilationError: If a VARIABLE block references a missing input or one whose type
            is not STRING, JSON or CHAT_HISTORY; if a RICH_TEXT child cannot be compiled; if a
            FUNCTION_DEFINITION block is passed in; or if the block type is unknown.
    """
    # Local import keeps this function self-contained; the constants module is already loaded
    # by the module header (for DEFAULT_JINJA_CUSTOM_FILTERS), so this costs nothing extra.
    from vellum.utils.templating.constants import DEFAULT_JINJA_GLOBALS

    sanitized_inputs = _sanitize_inputs(inputs)

    compiled_blocks: list[CompiledPromptBlock] = []
    for block in blocks:
        if block.state == "DISABLED":
            continue

        if block.block_type == "CHAT_MESSAGE":
            chat_role = cast_not_optional(block.chat_role)
            inner_blocks = cast_not_optional(block.blocks)
            unterminated = block.chat_message_unterminated or False

            inner_prompt_blocks = compile_prompt_blocks(
                inner_blocks,
                sanitized_inputs,
                input_variables,
            )
            if not inner_prompt_blocks:
                continue

            compiled_blocks.append(
                CompiledChatMessagePromptBlock(
                    role=chat_role,
                    unterminated=unterminated,
                    source=block.chat_source,
                    # A compiled chat message only holds value blocks; nested chat messages are dropped.
                    blocks=[inner for inner in inner_prompt_blocks if inner.block_type == "VALUE"],
                    cache_config=block.cache_config,
                )
            )

        elif block.block_type == "JINJA":
            if block.template is None:
                continue

            rendered_template = render_sandboxed_jinja_template(
                template=block.template,
                input_values={input_.key: input_.value for input_ in sanitized_inputs},
                jinja_custom_filters=DEFAULT_JINJA_CUSTOM_FILTERS,
                # Globals must be the module globals (json, datetime, ...), matching what
                # TemplatingNode exposes by default -- not the custom filters mapping.
                jinja_globals=DEFAULT_JINJA_GLOBALS,
            )
            jinja_content = StringVellumValue(value=rendered_template)

            compiled_blocks.append(
                CompiledValuePromptBlock(
                    content=jinja_content,
                    cache_config=block.cache_config,
                )
            )

        elif block.block_type == "VARIABLE":
            compiled_input: Optional[PromptRequestInput] = next(
                (input_ for input_ in sanitized_inputs if input_.key == str(block.input_variable)),
                None,
            )
            if compiled_input is None:
                raise PromptCompilationError(f"Input variable '{block.input_variable}' not found")

            if compiled_input.type == "CHAT_HISTORY":
                history = cast(list[ChatMessage], compiled_input.value)
                compiled_blocks.extend(_compile_chat_messages_as_prompt_blocks(history))
            elif compiled_input.type == "STRING":
                compiled_blocks.append(
                    CompiledValuePromptBlock(
                        content=StringVellumValue(value=compiled_input.value),
                        cache_config=block.cache_config,
                    )
                )
            elif compiled_input.type == "JSON":
                # Compare the input's *type* -- comparing the input object itself to "JSON"
                # never matched, so JSON inputs were always rejected.
                compiled_blocks.append(
                    CompiledValuePromptBlock(
                        content=JsonVellumValue(value=compiled_input.value),
                        cache_config=block.cache_config,
                    )
                )
            else:
                raise PromptCompilationError(f"Invalid input type for variable block: {compiled_input.type}")

        elif block.block_type == "RICH_TEXT":
            value_block = _compile_rich_text_block_as_value_block(block=block, inputs=sanitized_inputs)
            compiled_blocks.append(value_block)

        elif block.block_type == "FUNCTION_DEFINITION":
            raise PromptCompilationError("Function definitions shouldn't go through compilation process")
        else:
            raise PromptCompilationError(f"Unknown block_type: {block.block_type}")

    return compiled_blocks
119
+
120
+
121
def _compile_chat_messages_as_prompt_blocks(chat_messages: list[ChatMessage]) -> list[CompiledChatMessagePromptBlock]:
    """Turn chat history messages into compiled chat message blocks.

    Messages whose ``content`` is ``None`` are dropped. ``ARRAY`` content is split into one
    value block per item; any other content becomes a single value block. Every resulting
    chat message is terminated (``unterminated=False``) and keeps the message's role and source.
    """
    compiled_messages: list[CompiledChatMessagePromptBlock] = []
    for message in chat_messages:
        content = message.content
        if content is None:
            continue

        # NOTE(review): ARRAY items are forwarded as-is into `content`, which is typed
        # VellumValue -- confirm the chat content item types validate against it.
        if content.type == "ARRAY":
            value_blocks = [CompiledValuePromptBlock(content=item) for item in content.value]
        else:
            value_blocks = [CompiledValuePromptBlock(content=content)]

        compiled_messages.append(
            CompiledChatMessagePromptBlock(
                role=message.role,
                unterminated=False,
                blocks=value_blocks,
                source=message.source,
            )
        )

    return compiled_messages
152
+
153
+
154
def _compile_rich_text_block_as_value_block(
    block: RichTextPromptBlock,
    inputs: list[PromptRequestInput],
) -> CompiledValuePromptBlock:
    """Flatten a RICH_TEXT block into a single string value block.

    PLAIN_TEXT children contribute their text verbatim. VARIABLE children are replaced by the
    matching input's value: STRING inputs via ``str()``, JSON inputs via
    ``json.dumps(..., indent=4)``. The result keeps the block's ``cache_config``.

    Raises:
        PromptCompilationError: If a referenced input is missing, is neither STRING nor JSON,
            or a child block is not PLAIN_TEXT / VARIABLE.
    """
    pieces: list[str] = []
    for child in block.blocks:
        kind = child.block_type

        if kind == "PLAIN_TEXT":
            pieces.append(child.text)
            continue

        if kind != "VARIABLE":
            raise PromptCompilationError(f"Invalid child block_type for RICH_TEXT: {child.block_type}")

        variable = next((input_ for input_ in inputs if input_.key == str(child.input_variable)), None)
        if variable is None:
            raise PromptCompilationError(f"Input variable '{child.input_variable}' not found")

        if variable.type == "STRING":
            pieces.append(str(variable.value))
        elif variable.type == "JSON":
            pieces.append(json.dumps(variable.value, indent=4))
        else:
            raise PromptCompilationError(
                f"Input variable '{child.input_variable}' must be of type STRING or JSON"
            )

    flattened = "".join(pieces)
    return CompiledValuePromptBlock(content=StringVellumValue(value=flattened), cache_config=block.cache_config)
178
+
179
+
180
def _sanitize_inputs(inputs: list[PromptRequestInput]) -> list[PromptRequestInput]:
    """Return the inputs with ``None`` values replaced by type-appropriate empty values.

    A CHAT_HISTORY input whose value is ``None`` is copied (``model_copy``) with an empty list.
    A STRING input whose value is ``None`` is copied with ``""``. Every other input is returned
    as the same object.
    """

    def _sanitize_one(input_: PromptRequestInput) -> PromptRequestInput:
        if input_.value is not None:
            return input_
        if input_.type == "CHAT_HISTORY":
            return input_.model_copy(update={"value": cast(list[ChatMessage], [])})
        if input_.type == "STRING":
            return input_.model_copy(update={"value": ""})
        return input_

    return [_sanitize_one(input_) for input_ in inputs]
@@ -0,0 +1,2 @@
1
class PromptCompilationError(Exception):
    """Raised when Prompt Blocks cannot be compiled, e.g. a referenced input variable is missing."""
File without changes
@@ -0,0 +1,110 @@
1
+ import pytest
2
+
3
+ from vellum import (
4
+ ChatMessagePromptBlock,
5
+ JinjaPromptBlock,
6
+ PlainTextPromptBlock,
7
+ PromptRequestStringInput,
8
+ RichTextPromptBlock,
9
+ StringVellumValue,
10
+ VariablePromptBlock,
11
+ VellumVariable,
12
+ )
13
+ from vellum.prompts.blocks.compilation import compile_prompt_blocks
14
+ from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock, CompiledValuePromptBlock
15
+
16
+
17
@pytest.mark.parametrize(
    ["blocks", "inputs", "input_variables", "expected"],
    [
        pytest.param([], [], [], [], id="empty"),
        pytest.param(
            [JinjaPromptBlock(template="Hello, world!")],
            [],
            [],
            [CompiledValuePromptBlock(content=StringVellumValue(value="Hello, world!"))],
            id="jinja-no-variables",
        ),
        pytest.param(
            [JinjaPromptBlock(template="Repeat back to me {{ echo }}")],
            [PromptRequestStringInput(key="echo", value="Hello, world!")],
            [VellumVariable(id="1", type="STRING", key="echo")],
            [CompiledValuePromptBlock(content=StringVellumValue(value="Repeat back to me Hello, world!"))],
            id="jinja-with-variables",
        ),
        pytest.param(
            [RichTextPromptBlock(blocks=[PlainTextPromptBlock(text="Hello, world!")])],
            [],
            [],
            [CompiledValuePromptBlock(content=StringVellumValue(value="Hello, world!"))],
            id="rich-text-no-variables",
        ),
        pytest.param(
            [
                RichTextPromptBlock(
                    blocks=[
                        PlainTextPromptBlock(text='Repeat back to me "'),
                        VariablePromptBlock(input_variable="echo"),
                        PlainTextPromptBlock(text='".'),
                    ]
                )
            ],
            [PromptRequestStringInput(key="echo", value="Hello, world!")],
            [VellumVariable(id="901ec2d6-430c-4341-b963-ca689006f5cc", type="STRING", key="echo")],
            [CompiledValuePromptBlock(content=StringVellumValue(value='Repeat back to me "Hello, world!".'))],
            id="rich-text-with-variables",
        ),
        pytest.param(
            [
                ChatMessagePromptBlock(
                    chat_role="USER",
                    blocks=[
                        RichTextPromptBlock(
                            blocks=[
                                PlainTextPromptBlock(text='Repeat back to me "'),
                                VariablePromptBlock(input_variable="echo"),
                                PlainTextPromptBlock(text='".'),
                            ]
                        )
                    ],
                )
            ],
            [PromptRequestStringInput(key="echo", value="Hello, world!")],
            [VellumVariable(id="901ec2d6-430c-4341-b963-ca689006f5cc", type="STRING", key="echo")],
            [
                CompiledChatMessagePromptBlock(
                    role="USER",
                    blocks=[
                        CompiledValuePromptBlock(content=StringVellumValue(value='Repeat back to me "Hello, world!".'))
                    ],
                )
            ],
            id="chat-message",
        ),
    ],
)
def test_compile_prompt_blocks__happy(blocks, inputs, input_variables, expected):
    """Each happy-path case compiles to exactly the expected list of compiled blocks."""
    # WHEN the blocks are compiled against the given inputs
    actual = compile_prompt_blocks(blocks=blocks, inputs=inputs, input_variables=input_variables)

    # THEN the result matches the expected compiled blocks
    assert actual == expected
@@ -0,0 +1,36 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Annotated, Literal, Optional, Union
4
+
5
+ from vellum import ArrayVellumValue, ChatMessageRole, EphemeralPromptCacheConfig, VellumValue
6
+ from vellum.client.core import UniversalBaseModel
7
+
8
+
9
class BaseCompiledPromptBlock(UniversalBaseModel):
    """Fields shared by every compiled prompt block."""

    # Optional ephemeral cache configuration; None unless a compiler sets it.
    cache_config: Optional[EphemeralPromptCacheConfig] = None
11
+
12
+
13
class CompiledValuePromptBlock(BaseCompiledPromptBlock):
    """A compiled block that carries a single resolved value."""

    # Tag identifying this member of the CompiledPromptBlock union.
    block_type: Literal["VALUE"] = "VALUE"
    # The resolved value, e.g. rendered Jinja text or a substituted variable.
    content: VellumValue
16
+
17
+
18
class CompiledChatMessagePromptBlock(BaseCompiledPromptBlock):
    """A compiled chat message whose body is a list of value blocks."""

    # Tag identifying this member of the CompiledPromptBlock union.
    block_type: Literal["CHAT_MESSAGE"] = "CHAT_MESSAGE"
    # Author role of the message; defaults to "ASSISTANT".
    role: ChatMessageRole = "ASSISTANT"
    # Whether the message is left open (unterminated); defaults to False.
    unterminated: bool = False
    # Message body; only value blocks are allowed, so chat messages cannot nest.
    blocks: list[CompiledValuePromptBlock] = []
    # Optional source identifier carried through from the originating message/block.
    source: Optional[str] = None
24
+
25
+
26
# Any compiled prompt block; members are distinguished by their `block_type` literal.
# NOTE(review): a bare "block_type" string inside Annotated[...] is inert metadata -- pydantic
# does not read it as a discriminator. Confirm whether a tagged union (e.g. a pydantic Field
# with discriminator="block_type") was intended.
CompiledPromptBlock = Annotated[
    Union[CompiledValuePromptBlock, CompiledChatMessagePromptBlock],
    "block_type",
]

# Rebuild the models now that every referenced type is defined (this module uses postponed
# annotations); ArrayVellumValue first, then the value block that references VellumValue.
ArrayVellumValue.model_rebuild()
CompiledValuePromptBlock.model_rebuild()
File without changes
File without changes
@@ -0,0 +1,28 @@
1
+ import datetime
2
+ import itertools
3
+ import json
4
+ import random
5
+ import re
6
+ from typing import Any, Callable, Dict, Union
7
+
8
+ import dateutil
9
+ import pydash
10
+ import pytz
11
+ import yaml
12
+
13
+ from vellum.utils.templating.custom_filters import is_valid_json_string
14
+
15
# Load the parser submodule explicitly. A bare `import dateutil` does not make `dateutil.parser`
# available on older python-dateutil releases (before lazy submodule loading was added), so
# templates calling `dateutil.parser.parse(...)` would fail there. The previous TemplatingNode
# globals imported `dateutil.parser` for exactly this reason.
import dateutil.parser  # noqa: F401

# Modules exposed as globals to sandboxed Jinja templates, keyed by the name templates use.
DEFAULT_JINJA_GLOBALS: Dict[str, Any] = {
    "datetime": datetime,
    "dateutil": dateutil,
    "itertools": itertools,
    "json": json,
    "pydash": pydash,
    "pytz": pytz,
    "random": random,
    "re": re,
    "yaml": yaml,
}

# Custom Jinja filters available to templates, keyed by filter name.
DEFAULT_JINJA_CUSTOM_FILTERS: Dict[str, Callable[[Union[str, bytes]], bool]] = {
    "is_valid_json_string": is_valid_json_string,
}
@@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, Optional, Union
3
3
 
4
4
  from jinja2.sandbox import SandboxedEnvironment
5
5
 
6
- from vellum.workflows.nodes.core.templating_node.exceptions import JinjaTemplateError
6
+ from vellum.utils.templating.exceptions import JinjaTemplateError
7
7
  from vellum.workflows.state.encoder import DefaultStateEncoder
8
8
 
9
9
 
@@ -1,7 +1,5 @@
1
1
  from .base import BaseNode
2
- from .base_subworkflow_node import BaseSubworkflowNode
3
2
 
4
3
  __all__ = [
5
4
  "BaseNode",
6
- "BaseSubworkflowNode",
7
5
  ]
@@ -344,8 +344,8 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
344
344
  all_inputs = {}
345
345
  for key, value in inputs.items():
346
346
  path_parts = key.split(".")
347
- node_attribute_discriptor = getattr(self.__class__, path_parts[0])
348
- inputs_key = reduce(lambda acc, part: acc[part], path_parts[1:], node_attribute_discriptor)
347
+ node_attribute_descriptor = getattr(self.__class__, path_parts[0])
348
+ inputs_key = reduce(lambda acc, part: acc[part], path_parts[1:], node_attribute_descriptor)
349
349
  all_inputs[inputs_key] = value
350
350
 
351
351
  self._inputs = MappingProxyType(all_inputs)
@@ -355,7 +355,3 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
355
355
 
356
356
  def __repr__(self) -> str:
357
357
  return str(self.__class__)
358
-
359
-
360
- class MyNode2(BaseNode):
361
- pass
@@ -1,12 +1,14 @@
1
- from typing import TYPE_CHECKING, Generic, Iterator, Optional, Set, Type, TypeVar
1
+ from typing import TYPE_CHECKING, ClassVar, Generic, Iterator, Optional, Set, Type, TypeVar, Union
2
2
 
3
3
  from vellum.workflows.context import execution_context, get_parent_context
4
4
  from vellum.workflows.errors.types import WorkflowErrorCode
5
5
  from vellum.workflows.exceptions import NodeException
6
- from vellum.workflows.nodes.bases.base_subworkflow_node import BaseSubworkflowNode
6
+ from vellum.workflows.inputs.base import BaseInputs
7
+ from vellum.workflows.nodes.bases.base import BaseNode
7
8
  from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
8
9
  from vellum.workflows.state.base import BaseState
9
10
  from vellum.workflows.state.context import WorkflowContext
11
+ from vellum.workflows.types.core import EntityInputsInterface
10
12
  from vellum.workflows.types.generics import StateType, WorkflowInputsType
11
13
 
12
14
  if TYPE_CHECKING:
@@ -15,7 +17,7 @@ if TYPE_CHECKING:
15
17
  InnerStateType = TypeVar("InnerStateType", bound=BaseState)
16
18
 
17
19
 
18
- class InlineSubworkflowNode(BaseSubworkflowNode[StateType], Generic[StateType, WorkflowInputsType, InnerStateType]):
20
+ class InlineSubworkflowNode(BaseNode[StateType], Generic[StateType, WorkflowInputsType, InnerStateType]):
19
21
  """
20
22
  Used to execute a Subworkflow defined inline.
21
23
 
@@ -24,14 +26,13 @@ class InlineSubworkflowNode(BaseSubworkflowNode[StateType], Generic[StateType, W
24
26
  """
25
27
 
26
28
  subworkflow: Type["BaseWorkflow[WorkflowInputsType, InnerStateType]"]
29
+ subworkflow_inputs: ClassVar[Union[EntityInputsInterface, BaseInputs]] = {}
27
30
 
28
31
  def run(self) -> Iterator[BaseOutput]:
29
32
  with execution_context(parent_context=get_parent_context() or self._context.parent_context):
30
33
  subworkflow = self.subworkflow(
31
34
  parent_state=self.state,
32
- context=WorkflowContext(
33
- _vellum_client=self._context._vellum_client,
34
- ),
35
+ context=WorkflowContext(vellum_client=self._context.vellum_client),
35
36
  )
36
37
  subworkflow_stream = subworkflow.stream(
37
38
  inputs=self._compile_subworkflow_inputs(),
@@ -68,4 +69,9 @@ class InlineSubworkflowNode(BaseSubworkflowNode[StateType], Generic[StateType, W
68
69
 
69
70
  def _compile_subworkflow_inputs(self) -> WorkflowInputsType:
70
71
  inputs_class = self.subworkflow.get_inputs_class()
71
- return inputs_class(**self.subworkflow_inputs)
72
+ if isinstance(self.subworkflow_inputs, dict):
73
+ return inputs_class(**self.subworkflow_inputs)
74
+ elif isinstance(self.subworkflow_inputs, inputs_class):
75
+ return self.subworkflow_inputs
76
+ else:
77
+ raise ValueError(f"Invalid subworkflow inputs type: {type(self.subworkflow_inputs)}")
@@ -0,0 +1,41 @@
1
+ import pytest
2
+
3
+ from vellum.workflows.inputs.base import BaseInputs
4
+ from vellum.workflows.nodes.bases.base import BaseNode
5
+ from vellum.workflows.nodes.core.inline_subworkflow_node.node import InlineSubworkflowNode
6
+ from vellum.workflows.outputs.base import BaseOutput
7
+ from vellum.workflows.state.base import BaseState
8
+ from vellum.workflows.workflows.base import BaseWorkflow
9
+
10
+
11
class Inputs(BaseInputs):
    """Inputs for MySubworkflow: a single string field `foo`."""

    foo: str
13
+
14
+
15
class MyInnerNode(BaseNode):
    """Inner node that echoes the subworkflow's `foo` input as its `out` output."""

    class Outputs(BaseNode.Outputs):
        # References the workflow input, so `out` takes the value passed as `foo`.
        out = Inputs.foo
18
+
19
+
20
class MySubworkflow(BaseWorkflow[Inputs, BaseState]):
    """Single-node subworkflow that surfaces MyInnerNode's `out` as its own `out` output."""

    graph = MyInnerNode

    class Outputs(BaseWorkflow.Outputs):
        out = MyInnerNode.Outputs.out
25
+
26
+
27
@pytest.mark.parametrize("inputs", [{"foo": "bar"}, Inputs(foo="bar")])
def test_inline_subworkflow_node__inputs(inputs):
    """InlineSubworkflowNode accepts subworkflow inputs either as a dict or as an Inputs instance."""

    # GIVEN a node whose subworkflow inputs are supplied in the parametrized form
    class MyNode(InlineSubworkflowNode):
        subworkflow = MySubworkflow
        subworkflow_inputs = inputs

    # WHEN the node is run to completion
    emitted = list(MyNode().run())

    # THEN the subworkflow's single output is streamed back
    expected = [BaseOutput(name="out", value="bar")]
    assert emitted == expected
@@ -108,7 +108,7 @@ class MapNode(BaseNode, Generic[StateType, MapNodeItemType]):
108
108
  self._run_subworkflow(item=item, index=index)
109
109
 
110
110
  def _run_subworkflow(self, *, item: MapNodeItemType, index: int) -> None:
111
- context = WorkflowContext(_vellum_client=self._context._vellum_client)
111
+ context = WorkflowContext(vellum_client=self._context.vellum_client)
112
112
  subworkflow = self.subworkflow(parent_state=self.state, context=context)
113
113
  events = subworkflow.stream(
114
114
  inputs=self.SubworkflowInputs(index=index, item=item, all_items=self.items),
@@ -1,42 +1,17 @@
1
- import datetime
2
- import itertools
3
1
  import json
4
- import random
5
- import re
6
2
  from typing import Any, Callable, ClassVar, Dict, Generic, Mapping, Tuple, Type, TypeVar, Union, get_args
7
3
 
8
- import dateutil.parser
9
- import pydash
10
- import pytz
11
- import yaml
12
-
4
+ from vellum.utils.templating.constants import DEFAULT_JINJA_CUSTOM_FILTERS, DEFAULT_JINJA_GLOBALS
5
+ from vellum.utils.templating.exceptions import JinjaTemplateError
6
+ from vellum.utils.templating.render import render_sandboxed_jinja_template
13
7
  from vellum.workflows.errors import WorkflowErrorCode
14
8
  from vellum.workflows.exceptions import NodeException
15
9
  from vellum.workflows.nodes.bases import BaseNode
16
10
  from vellum.workflows.nodes.bases.base import BaseNodeMeta
17
- from vellum.workflows.nodes.core.templating_node.custom_filters import is_valid_json_string
18
- from vellum.workflows.nodes.core.templating_node.exceptions import JinjaTemplateError
19
- from vellum.workflows.nodes.core.templating_node.render import render_sandboxed_jinja_template
20
- from vellum.workflows.types.core import EntityInputsInterface
11
+ from vellum.workflows.types.core import EntityInputsInterface, Json
21
12
  from vellum.workflows.types.generics import StateType
22
13
  from vellum.workflows.types.utils import get_original_base
23
14
 
24
- _DEFAULT_JINJA_GLOBALS: Dict[str, Any] = {
25
- "datetime": datetime,
26
- "dateutil": dateutil,
27
- "itertools": itertools,
28
- "json": json,
29
- "pydash": pydash,
30
- "pytz": pytz,
31
- "random": random,
32
- "re": re,
33
- "yaml": yaml,
34
- }
35
-
36
- _DEFAULT_JINJA_CUSTOM_FILTERS: Dict[str, Callable[[Union[str, bytes]], bool]] = {
37
- "is_valid_json_string": is_valid_json_string,
38
- }
39
-
40
15
  _OutputType = TypeVar("_OutputType")
41
16
 
42
17
 
@@ -78,8 +53,8 @@ class TemplatingNode(BaseNode[StateType], Generic[StateType, _OutputType], metac
78
53
  # The inputs to render the template with.
79
54
  inputs: ClassVar[EntityInputsInterface]
80
55
 
81
- jinja_globals: Dict[str, Any] = _DEFAULT_JINJA_GLOBALS
82
- jinja_custom_filters: Mapping[str, Callable[[Union[str, bytes]], bool]] = _DEFAULT_JINJA_CUSTOM_FILTERS
56
+ jinja_globals: Dict[str, Any] = DEFAULT_JINJA_GLOBALS
57
+ jinja_custom_filters: Mapping[str, Callable[[Union[str, bytes]], bool]] = DEFAULT_JINJA_CUSTOM_FILTERS
83
58
 
84
59
  class Outputs(BaseNode.Outputs):
85
60
  """
@@ -113,6 +88,12 @@ class TemplatingNode(BaseNode[StateType], Generic[StateType, _OutputType], metac
113
88
  if output_type is bool:
114
89
  return bool(rendered_template)
115
90
 
91
+ if output_type is Json:
92
+ try:
93
+ return json.loads(rendered_template)
94
+ except json.JSONDecodeError:
95
+ raise ValueError("Invalid JSON format for rendered_template")
96
+
116
97
  raise ValueError(f"Unsupported output type: {output_type}")
117
98
 
118
99
  def run(self) -> Outputs: