vellum-ai 0.14.41__py3-none-any.whl → 0.14.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/__init__.py +4 -4
- vellum/client/core/client_wrapper.py +2 -2
- vellum/client/reference.md +110 -3
- vellum/client/resources/documents/client.py +0 -6
- vellum/client/resources/prompts/client.py +228 -1
- vellum/client/types/__init__.py +4 -4
- vellum/client/types/deployment_read.py +2 -2
- vellum/client/types/execute_api_response.py +3 -4
- vellum/client/types/execute_api_response_json.py +7 -0
- vellum/client/types/{workflow_event_display_context.py → prompt_push_response.py} +4 -12
- vellum/client/types/prompt_settings.py +1 -0
- vellum/client/types/workflow_event_execution_read.py +0 -4
- vellum/client/types/workflow_execution_initiated_body.py +0 -9
- vellum/client/types/workflow_execution_initiated_event.py +0 -4
- vellum/client/types/workflow_execution_span.py +0 -4
- vellum/types/{node_event_display_context.py → execute_api_response_json.py} +1 -1
- vellum/types/{workflow_event_display_context.py → prompt_push_response.py} +1 -1
- vellum/workflows/inputs/base.py +26 -3
- vellum/workflows/inputs/tests/test_inputs.py +15 -0
- vellum/workflows/nodes/bases/base.py +0 -3
- vellum/workflows/nodes/bases/base_adornment_node.py +9 -0
- vellum/workflows/nodes/bases/tests/test_base_adornment_node.py +31 -0
- vellum/workflows/nodes/core/map_node/node.py +3 -2
- vellum/workflows/nodes/core/map_node/tests/test_node.py +56 -0
- vellum/workflows/nodes/core/retry_node/node.py +2 -1
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +62 -13
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/tests/test_inline_prompt_node.py +177 -0
- vellum/workflows/nodes/experimental/tool_calling_node/node.py +3 -6
- vellum/workflows/nodes/experimental/tool_calling_node/utils.py +18 -15
- vellum/workflows/nodes/utils.py +14 -1
- vellum/workflows/references/output.py +1 -1
- vellum/workflows/references/workflow_input.py +5 -1
- vellum/workflows/runner/runner.py +2 -0
- vellum/workflows/workflows/base.py +5 -0
- {vellum_ai-0.14.41.dist-info → vellum_ai-0.14.43.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.41.dist-info → vellum_ai-0.14.43.dist-info}/RECORD +68 -67
- vellum_cli/pull.py +7 -0
- vellum_cli/tests/test_pull.py +23 -0
- vellum_ee/workflows/display/nodes/base_node_display.py +32 -23
- vellum_ee/workflows/display/nodes/vellum/api_node.py +1 -0
- vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +1 -0
- vellum_ee/workflows/display/nodes/vellum/conditional_node.py +1 -0
- vellum_ee/workflows/display/nodes/vellum/final_output_node.py +6 -6
- vellum_ee/workflows/display/nodes/vellum/guardrail_node.py +1 -0
- vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +1 -0
- vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +1 -0
- vellum_ee/workflows/display/nodes/vellum/map_node.py +15 -12
- vellum_ee/workflows/display/nodes/vellum/merge_node.py +1 -0
- vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +1 -0
- vellum_ee/workflows/display/nodes/vellum/search_node.py +1 -0
- vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +1 -0
- vellum_ee/workflows/display/nodes/vellum/templating_node.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +3 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +138 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +3 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +2 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +2 -2
- vellum/client/types/node_event_display_context.py +0 -30
- {vellum_ai-0.14.41.dist-info → vellum_ai-0.14.43.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.41.dist-info → vellum_ai-0.14.43.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.41.dist-info → vellum_ai-0.14.43.dist-info}/entry_points.txt +0 -0
@@ -1,19 +1,14 @@
|
|
1
1
|
# This file was auto-generated by Fern from our API Definition.
|
2
2
|
|
3
|
-
from __future__ import annotations
|
4
3
|
from ..core.pydantic_utilities import UniversalBaseModel
|
5
|
-
from .node_event_display_context import NodeEventDisplayContext
|
6
|
-
from .workflow_event_display_context import WorkflowEventDisplayContext
|
7
4
|
from .vellum_code_resource_definition import VellumCodeResourceDefinition
|
8
5
|
import typing
|
9
6
|
from ..core.pydantic_utilities import IS_PYDANTIC_V2
|
10
7
|
import pydantic
|
11
|
-
from ..core.pydantic_utilities import update_forward_refs
|
12
8
|
|
13
9
|
|
14
10
|
class WorkflowExecutionInitiatedBody(UniversalBaseModel):
|
15
11
|
workflow_definition: VellumCodeResourceDefinition
|
16
|
-
display_context: typing.Optional[WorkflowEventDisplayContext] = None
|
17
12
|
inputs: typing.Dict[str, typing.Optional[typing.Any]]
|
18
13
|
|
19
14
|
if IS_PYDANTIC_V2:
|
@@ -24,7 +19,3 @@ class WorkflowExecutionInitiatedBody(UniversalBaseModel):
|
|
24
19
|
frozen = True
|
25
20
|
smart_union = True
|
26
21
|
extra = pydantic.Extra.allow
|
27
|
-
|
28
|
-
|
29
|
-
update_forward_refs(NodeEventDisplayContext, WorkflowExecutionInitiatedBody=WorkflowExecutionInitiatedBody)
|
30
|
-
update_forward_refs(WorkflowEventDisplayContext, WorkflowExecutionInitiatedBody=WorkflowExecutionInitiatedBody)
|
@@ -9,8 +9,6 @@ from .span_link import SpanLink
|
|
9
9
|
from .workflow_deployment_parent_context import WorkflowDeploymentParentContext
|
10
10
|
from .workflow_parent_context import WorkflowParentContext
|
11
11
|
from .workflow_sandbox_parent_context import WorkflowSandboxParentContext
|
12
|
-
from .node_event_display_context import NodeEventDisplayContext
|
13
|
-
from .workflow_event_display_context import WorkflowEventDisplayContext
|
14
12
|
import typing
|
15
13
|
from .parent_context import ParentContext
|
16
14
|
from .workflow_execution_initiated_body import WorkflowExecutionInitiatedBody
|
@@ -49,5 +47,3 @@ update_forward_refs(SpanLink, WorkflowExecutionInitiatedEvent=WorkflowExecutionI
|
|
49
47
|
update_forward_refs(WorkflowDeploymentParentContext, WorkflowExecutionInitiatedEvent=WorkflowExecutionInitiatedEvent)
|
50
48
|
update_forward_refs(WorkflowParentContext, WorkflowExecutionInitiatedEvent=WorkflowExecutionInitiatedEvent)
|
51
49
|
update_forward_refs(WorkflowSandboxParentContext, WorkflowExecutionInitiatedEvent=WorkflowExecutionInitiatedEvent)
|
52
|
-
update_forward_refs(NodeEventDisplayContext, WorkflowExecutionInitiatedEvent=WorkflowExecutionInitiatedEvent)
|
53
|
-
update_forward_refs(WorkflowEventDisplayContext, WorkflowExecutionInitiatedEvent=WorkflowExecutionInitiatedEvent)
|
@@ -3,12 +3,10 @@
|
|
3
3
|
from __future__ import annotations
|
4
4
|
from ..core.pydantic_utilities import UniversalBaseModel
|
5
5
|
from .api_request_parent_context import ApiRequestParentContext
|
6
|
-
from .node_event_display_context import NodeEventDisplayContext
|
7
6
|
from .node_parent_context import NodeParentContext
|
8
7
|
from .prompt_deployment_parent_context import PromptDeploymentParentContext
|
9
8
|
from .span_link import SpanLink
|
10
9
|
from .workflow_deployment_parent_context import WorkflowDeploymentParentContext
|
11
|
-
from .workflow_event_display_context import WorkflowEventDisplayContext
|
12
10
|
from .workflow_parent_context import WorkflowParentContext
|
13
11
|
from .workflow_sandbox_parent_context import WorkflowSandboxParentContext
|
14
12
|
import typing
|
@@ -40,11 +38,9 @@ class WorkflowExecutionSpan(UniversalBaseModel):
|
|
40
38
|
|
41
39
|
|
42
40
|
update_forward_refs(ApiRequestParentContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
|
43
|
-
update_forward_refs(NodeEventDisplayContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
|
44
41
|
update_forward_refs(NodeParentContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
|
45
42
|
update_forward_refs(PromptDeploymentParentContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
|
46
43
|
update_forward_refs(SpanLink, WorkflowExecutionSpan=WorkflowExecutionSpan)
|
47
44
|
update_forward_refs(WorkflowDeploymentParentContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
|
48
|
-
update_forward_refs(WorkflowEventDisplayContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
|
49
45
|
update_forward_refs(WorkflowParentContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
|
50
46
|
update_forward_refs(WorkflowSandboxParentContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
|
vellum/workflows/inputs/base.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any, Iterator, Tuple, Type, Union, get_args, get_origin
|
1
|
+
from typing import Any, Dict, Iterator, Set, Tuple, Type, Union, get_args, get_origin
|
2
2
|
from typing_extensions import dataclass_transform
|
3
3
|
|
4
4
|
from pydantic import GetCoreSchemaHandler
|
@@ -14,11 +14,28 @@ from vellum.workflows.types.utils import get_class_attr_names, infer_types
|
|
14
14
|
|
15
15
|
@dataclass_transform(kw_only_default=True)
|
16
16
|
class _BaseInputsMeta(type):
|
17
|
+
def __new__(cls, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any]) -> Any:
|
18
|
+
dct["__parent_class__"] = type(None)
|
19
|
+
return super().__new__(cls, name, bases, dct)
|
20
|
+
|
17
21
|
def __getattribute__(cls, name: str) -> Any:
|
18
|
-
if
|
22
|
+
if name.startswith("_") or not issubclass(cls, BaseInputs):
|
23
|
+
return super().__getattribute__(name)
|
24
|
+
|
25
|
+
attr_names = get_class_attr_names(cls)
|
26
|
+
if name in attr_names:
|
27
|
+
# We first try to resolve the instance that this class attribute name is mapped to. If it's not found,
|
28
|
+
# we iterate through its inheritance hierarchy to find the first base class that has this attribute
|
29
|
+
# and use its mapping.
|
19
30
|
instance = vars(cls).get(name, undefined)
|
20
|
-
|
31
|
+
if instance is undefined:
|
32
|
+
for base in cls.__mro__[1:]:
|
33
|
+
inherited_input_reference = getattr(base, name, undefined)
|
34
|
+
if isinstance(inherited_input_reference, (ExternalInputReference, WorkflowInputReference)):
|
35
|
+
instance = inherited_input_reference.instance
|
36
|
+
break
|
21
37
|
|
38
|
+
types = infer_types(cls, name)
|
22
39
|
if getattr(cls, "__descriptor_class__", None) is ExternalInputReference:
|
23
40
|
return ExternalInputReference(name=name, types=types, instance=instance, inputs_class=cls)
|
24
41
|
else:
|
@@ -29,14 +46,20 @@ class _BaseInputsMeta(type):
|
|
29
46
|
def __iter__(cls) -> Iterator[InputReference]:
|
30
47
|
# We iterate through the inheritance hierarchy to find all the WorkflowInputReference attached to this
|
31
48
|
# Inputs class. __mro__ is the method resolution order, which is the order in which base classes are resolved.
|
49
|
+
yielded_attr_names: Set[str] = set()
|
50
|
+
|
32
51
|
for resolved_cls in cls.__mro__:
|
33
52
|
attr_names = get_class_attr_names(resolved_cls)
|
34
53
|
for attr_name in attr_names:
|
54
|
+
if attr_name in yielded_attr_names:
|
55
|
+
continue
|
56
|
+
|
35
57
|
attr_value = getattr(resolved_cls, attr_name)
|
36
58
|
if not isinstance(attr_value, (WorkflowInputReference, ExternalInputReference)):
|
37
59
|
continue
|
38
60
|
|
39
61
|
yield attr_value
|
62
|
+
yielded_attr_names.add(attr_name)
|
40
63
|
|
41
64
|
|
42
65
|
class BaseInputs(metaclass=_BaseInputsMeta):
|
@@ -47,3 +47,18 @@ def test_base_inputs_with_default():
|
|
47
47
|
|
48
48
|
# THEN it should use the default value
|
49
49
|
assert inputs.string_with_default == "default_value"
|
50
|
+
|
51
|
+
|
52
|
+
def test_base_inputs__supports_inherited_inputs():
|
53
|
+
# GIVEN an inputs class
|
54
|
+
class TopInputs(BaseInputs):
|
55
|
+
first: str
|
56
|
+
|
57
|
+
# WHEN we inherit from the base inputs class
|
58
|
+
class BottomInputs(TopInputs):
|
59
|
+
second: int
|
60
|
+
|
61
|
+
# THEN both references should be available
|
62
|
+
assert BottomInputs.first.name == "first"
|
63
|
+
assert BottomInputs.second.name == "second"
|
64
|
+
assert len([ref for ref in BottomInputs]) == 2
|
@@ -41,9 +41,6 @@ def is_nested_class(nested: Any, parent: Type) -> bool:
|
|
41
41
|
|
42
42
|
class BaseNodeMeta(type):
|
43
43
|
def __new__(mcs, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any]) -> Any:
|
44
|
-
# TODO: Inherit the inner Output classes from every base class.
|
45
|
-
# https://app.shortcut.com/vellum/story/4007/support-auto-inheriting-parent-node-outputs
|
46
|
-
|
47
44
|
if "Outputs" in dct:
|
48
45
|
outputs_class = dct["Outputs"]
|
49
46
|
if not any(issubclass(base, BaseOutputs) for base in outputs_class.__bases__):
|
@@ -1,5 +1,6 @@
|
|
1
1
|
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Tuple, Type
|
2
2
|
|
3
|
+
from vellum.workflows.inputs.base import BaseInputs
|
3
4
|
from vellum.workflows.nodes.bases.base import BaseNode, BaseNodeMeta
|
4
5
|
from vellum.workflows.outputs.base import BaseOutputs
|
5
6
|
from vellum.workflows.references.output import OutputReference
|
@@ -13,6 +14,14 @@ class _BaseAdornmentNodeMeta(BaseNodeMeta):
|
|
13
14
|
def __new__(cls, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any]) -> Any:
|
14
15
|
node_class = super().__new__(cls, name, bases, dct)
|
15
16
|
|
17
|
+
SubworkflowInputs = dct.get("SubworkflowInputs")
|
18
|
+
if (
|
19
|
+
isinstance(SubworkflowInputs, type)
|
20
|
+
and issubclass(SubworkflowInputs, BaseInputs)
|
21
|
+
and SubworkflowInputs.__parent_class__ is type(None)
|
22
|
+
):
|
23
|
+
SubworkflowInputs.__parent_class__ = node_class
|
24
|
+
|
16
25
|
subworkflow_attribute = dct.get("subworkflow")
|
17
26
|
if not subworkflow_attribute:
|
18
27
|
return node_class
|
@@ -0,0 +1,31 @@
|
|
1
|
+
from typing import Callable, Type
|
2
|
+
|
3
|
+
from vellum.workflows.nodes.bases.base import BaseNode
|
4
|
+
from vellum.workflows.nodes.bases.base_adornment_node import BaseAdornmentNode
|
5
|
+
from vellum.workflows.nodes.utils import create_adornment
|
6
|
+
|
7
|
+
|
8
|
+
def test_base_adornment_node__output_references_of_same_name():
|
9
|
+
# GIVEN a custom adornment node
|
10
|
+
class CustomAdornmentNode(BaseAdornmentNode):
|
11
|
+
@classmethod
|
12
|
+
def wrap(cls) -> Callable[..., Type["CustomAdornmentNode"]]:
|
13
|
+
return create_adornment(cls)
|
14
|
+
|
15
|
+
# AND two nodes wrapped by the adornment with the same output
|
16
|
+
@CustomAdornmentNode.wrap()
|
17
|
+
class AppleNode(BaseNode):
|
18
|
+
class Outputs(BaseNode.Outputs):
|
19
|
+
fruit: str
|
20
|
+
|
21
|
+
@CustomAdornmentNode.wrap()
|
22
|
+
class OrangeNode(BaseNode):
|
23
|
+
class Outputs(BaseNode.Outputs):
|
24
|
+
fruit: str
|
25
|
+
|
26
|
+
# WHEN get output references of these outputs
|
27
|
+
apple_output_reference = AppleNode.Outputs.fruit
|
28
|
+
orange_output_reference = OrangeNode.Outputs.fruit
|
29
|
+
|
30
|
+
# THEN the output references should not be equal
|
31
|
+
assert apple_output_reference != orange_output_reference, "Output references should not be equal"
|
@@ -62,7 +62,7 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
|
|
62
62
|
|
63
63
|
item: MapNodeItemType # type: ignore[valid-type]
|
64
64
|
index: int
|
65
|
-
|
65
|
+
items: List[MapNodeItemType] # type: ignore[valid-type]
|
66
66
|
|
67
67
|
def run(self) -> Iterator[BaseOutput]:
|
68
68
|
mapped_items: Dict[str, List] = defaultdict(list)
|
@@ -176,8 +176,9 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
|
|
176
176
|
parent_state=self.state,
|
177
177
|
context=context,
|
178
178
|
)
|
179
|
+
SubworkflowInputsClass = self.subworkflow.get_inputs_class()
|
179
180
|
events = subworkflow.stream(
|
180
|
-
inputs=
|
181
|
+
inputs=SubworkflowInputsClass(index=index, item=item, items=self.items),
|
181
182
|
node_output_mocks=self._context._get_all_node_output_mocks(),
|
182
183
|
event_filter=all_workflow_event_filter,
|
183
184
|
)
|
@@ -116,3 +116,59 @@ def test_map_node__inner_try():
|
|
116
116
|
# THEN the workflow should succeed
|
117
117
|
assert outputs[-1].name == "final_output"
|
118
118
|
assert len(outputs[-1].value) == 2
|
119
|
+
|
120
|
+
|
121
|
+
def test_map_node__nested_map_node():
|
122
|
+
# GIVEN the inner map node's inputs
|
123
|
+
class VegetableMapNodeInputs(MapNode.SubworkflowInputs):
|
124
|
+
item: str
|
125
|
+
|
126
|
+
# AND the outer map node's inputs
|
127
|
+
class FruitMapNodeInputs(MapNode.SubworkflowInputs):
|
128
|
+
item: str
|
129
|
+
|
130
|
+
# AND a simple node that concats both attributes
|
131
|
+
class SimpleConcatNode(BaseNode):
|
132
|
+
fruit = FruitMapNodeInputs.item
|
133
|
+
vegetable = VegetableMapNodeInputs.item
|
134
|
+
|
135
|
+
class Outputs(BaseNode.Outputs):
|
136
|
+
medley: str
|
137
|
+
|
138
|
+
def run(self) -> Outputs:
|
139
|
+
return self.Outputs(medley=f"{self.fruit} {self.vegetable}")
|
140
|
+
|
141
|
+
# AND a workflow using that node
|
142
|
+
class VegetableMapNodeWorkflow(BaseWorkflow[VegetableMapNodeInputs, BaseState]):
|
143
|
+
graph = SimpleConcatNode
|
144
|
+
|
145
|
+
class Outputs(BaseWorkflow.Outputs):
|
146
|
+
final_output = SimpleConcatNode.Outputs.medley
|
147
|
+
|
148
|
+
# AND an inner map node referencing that workflow
|
149
|
+
class VegetableMapNode(MapNode):
|
150
|
+
items = ["carrot", "potato"]
|
151
|
+
subworkflow = VegetableMapNodeWorkflow
|
152
|
+
|
153
|
+
# AND an outer subworkflow referencing the inner map node
|
154
|
+
class FruitMapNodeWorkflow(BaseWorkflow[FruitMapNodeInputs, BaseState]):
|
155
|
+
graph = VegetableMapNode
|
156
|
+
|
157
|
+
class Outputs(BaseWorkflow.Outputs):
|
158
|
+
final_output = VegetableMapNode.Outputs.final_output
|
159
|
+
|
160
|
+
# AND an outer map node referencing the outer subworkflow
|
161
|
+
class FruitMapNode(MapNode):
|
162
|
+
items = ["apple", "banana"]
|
163
|
+
subworkflow = FruitMapNodeWorkflow
|
164
|
+
|
165
|
+
# WHEN we run the workflow
|
166
|
+
stream = FruitMapNode().run()
|
167
|
+
outputs = list(stream)
|
168
|
+
|
169
|
+
# THEN the workflow should succeed
|
170
|
+
assert outputs[-1].name == "final_output"
|
171
|
+
assert outputs[-1].value == [
|
172
|
+
["apple carrot", "apple potato"],
|
173
|
+
["banana carrot", "banana potato"],
|
174
|
+
]
|
@@ -47,8 +47,9 @@ class RetryNode(BaseAdornmentNode[StateType], Generic[StateType]):
|
|
47
47
|
parent_state=self.state,
|
48
48
|
context=WorkflowContext.create_from(self._context),
|
49
49
|
)
|
50
|
+
inputs_class = subworkflow.get_inputs_class()
|
50
51
|
subworkflow_stream = subworkflow.stream(
|
51
|
-
inputs=
|
52
|
+
inputs=inputs_class(attempt_number=attempt_number),
|
52
53
|
event_filter=all_workflow_event_filter,
|
53
54
|
node_output_mocks=self._context._get_all_node_output_mocks(),
|
54
55
|
)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import json
|
2
2
|
from uuid import uuid4
|
3
|
-
from typing import Callable, ClassVar, Generic, Iterator, List, Optional, Set, Tuple, Union
|
3
|
+
from typing import Callable, ClassVar, Generator, Generic, Iterator, List, Optional, Set, Tuple, Union
|
4
4
|
|
5
5
|
from vellum import (
|
6
6
|
AdHocExecutePromptEvent,
|
@@ -8,6 +8,7 @@ from vellum import (
|
|
8
8
|
ChatMessage,
|
9
9
|
FunctionDefinition,
|
10
10
|
PromptBlock,
|
11
|
+
PromptOutput,
|
11
12
|
PromptParameters,
|
12
13
|
PromptRequestChatHistoryInput,
|
13
14
|
PromptRequestInput,
|
@@ -15,17 +16,19 @@ from vellum import (
|
|
15
16
|
PromptRequestStringInput,
|
16
17
|
VellumVariable,
|
17
18
|
)
|
18
|
-
from vellum.client import RequestOptions
|
19
|
+
from vellum.client import ApiError, RequestOptions
|
19
20
|
from vellum.client.types.chat_message_request import ChatMessageRequest
|
20
21
|
from vellum.client.types.prompt_settings import PromptSettings
|
21
22
|
from vellum.client.types.rich_text_child_block import RichTextChildBlock
|
22
23
|
from vellum.workflows.constants import OMIT
|
23
24
|
from vellum.workflows.context import get_execution_context
|
24
25
|
from vellum.workflows.errors import WorkflowErrorCode
|
26
|
+
from vellum.workflows.errors.types import vellum_error_to_workflow_error
|
25
27
|
from vellum.workflows.events.types import default_serializer
|
26
28
|
from vellum.workflows.exceptions import NodeException
|
27
29
|
from vellum.workflows.nodes.displayable.bases.base_prompt_node import BasePromptNode
|
28
30
|
from vellum.workflows.nodes.displayable.bases.inline_prompt_node.constants import DEFAULT_PROMPT_PARAMETERS
|
31
|
+
from vellum.workflows.outputs import BaseOutput
|
29
32
|
from vellum.workflows.types import MergeBehavior
|
30
33
|
from vellum.workflows.types.generics import StateType
|
31
34
|
from vellum.workflows.utils.functions import compile_function_definition
|
@@ -103,17 +106,63 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
|
|
103
106
|
else None
|
104
107
|
)
|
105
108
|
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
109
|
+
if self.settings and not self.settings.stream_enabled:
|
110
|
+
# This endpoint is returning a single event, so we need to wrap it in a generator
|
111
|
+
# to match the existing interface.
|
112
|
+
response = self._context.vellum_client.ad_hoc.adhoc_execute_prompt(
|
113
|
+
ml_model=self.ml_model,
|
114
|
+
input_values=input_values,
|
115
|
+
input_variables=input_variables,
|
116
|
+
parameters=self.parameters,
|
117
|
+
blocks=self.blocks,
|
118
|
+
settings=self.settings,
|
119
|
+
functions=normalized_functions,
|
120
|
+
expand_meta=self.expand_meta,
|
121
|
+
request_options=request_options,
|
122
|
+
)
|
123
|
+
return iter([response])
|
124
|
+
else:
|
125
|
+
return self._context.vellum_client.ad_hoc.adhoc_execute_prompt_stream(
|
126
|
+
ml_model=self.ml_model,
|
127
|
+
input_values=input_values,
|
128
|
+
input_variables=input_variables,
|
129
|
+
parameters=self.parameters,
|
130
|
+
blocks=self.blocks,
|
131
|
+
settings=self.settings,
|
132
|
+
functions=normalized_functions,
|
133
|
+
expand_meta=self.expand_meta,
|
134
|
+
request_options=request_options,
|
135
|
+
)
|
136
|
+
|
137
|
+
def _process_prompt_event_stream(self) -> Generator[BaseOutput, None, Optional[List[PromptOutput]]]:
|
138
|
+
self._validate()
|
139
|
+
try:
|
140
|
+
prompt_event_stream = self._get_prompt_event_stream()
|
141
|
+
except ApiError as e:
|
142
|
+
self._handle_api_error(e)
|
143
|
+
|
144
|
+
if not self.settings or (self.settings and self.settings.stream_enabled):
|
145
|
+
# We don't use the INITIATED event anyway, so we can just skip it
|
146
|
+
# and use the exception handling to catch other api level errors
|
147
|
+
try:
|
148
|
+
next(prompt_event_stream)
|
149
|
+
except ApiError as e:
|
150
|
+
self._handle_api_error(e)
|
151
|
+
|
152
|
+
outputs: Optional[List[PromptOutput]] = None
|
153
|
+
for event in prompt_event_stream:
|
154
|
+
if event.state == "INITIATED":
|
155
|
+
continue
|
156
|
+
elif event.state == "STREAMING":
|
157
|
+
yield BaseOutput(name="results", delta=event.output.value)
|
158
|
+
elif event.state == "FULFILLED":
|
159
|
+
outputs = event.outputs
|
160
|
+
yield BaseOutput(name="results", value=event.outputs)
|
161
|
+
elif event.state == "REJECTED":
|
162
|
+
workflow_error = vellum_error_to_workflow_error(event.error)
|
163
|
+
raise NodeException.of(workflow_error)
|
164
|
+
|
165
|
+
return outputs
|
117
166
|
|
118
167
|
def _compile_prompt_inputs(self) -> Tuple[List[VellumVariable], List[PromptRequestInput]]:
|
119
168
|
input_variables: List[VellumVariable] = []
|
vellum/workflows/nodes/displayable/bases/inline_prompt_node/tests/test_inline_prompt_node.py
CHANGED
@@ -5,11 +5,14 @@ from uuid import uuid4
|
|
5
5
|
from typing import Any, Iterator, List
|
6
6
|
|
7
7
|
from vellum import (
|
8
|
+
AdHocExecutePromptEvent,
|
8
9
|
ChatMessagePromptBlock,
|
10
|
+
FulfilledAdHocExecutePromptEvent,
|
9
11
|
JinjaPromptBlock,
|
10
12
|
PlainTextPromptBlock,
|
11
13
|
PromptBlock,
|
12
14
|
PromptParameters,
|
15
|
+
PromptSettings,
|
13
16
|
RichTextPromptBlock,
|
14
17
|
VariablePromptBlock,
|
15
18
|
)
|
@@ -296,3 +299,177 @@ def test_inline_prompt_node__json_output(vellum_adhoc_prompt_client):
|
|
296
299
|
request_options=mock.ANY,
|
297
300
|
settings=None,
|
298
301
|
)
|
302
|
+
|
303
|
+
|
304
|
+
def test_inline_prompt_node__streaming_disabled(vellum_adhoc_prompt_client):
|
305
|
+
# GIVEN an InlinePromptNode
|
306
|
+
class Inputs(BaseInputs):
|
307
|
+
input: str
|
308
|
+
|
309
|
+
class State(BaseState):
|
310
|
+
pass
|
311
|
+
|
312
|
+
# AND it has streaming disabled
|
313
|
+
class MyInlinePromptNode(InlinePromptNode):
|
314
|
+
ml_model = "gpt-4o"
|
315
|
+
blocks = []
|
316
|
+
parameters = PromptParameters(
|
317
|
+
stop=[],
|
318
|
+
temperature=0.0,
|
319
|
+
max_tokens=4096,
|
320
|
+
top_p=1.0,
|
321
|
+
top_k=0,
|
322
|
+
frequency_penalty=0.0,
|
323
|
+
presence_penalty=0.0,
|
324
|
+
logit_bias=None,
|
325
|
+
custom_parameters={},
|
326
|
+
)
|
327
|
+
settings = PromptSettings(stream_enabled=False)
|
328
|
+
|
329
|
+
# AND a known response from invoking an inline prompt
|
330
|
+
expected_output: list[PromptOutput] = [StringVellumValue(value="Hello, world!")]
|
331
|
+
|
332
|
+
def generate_prompt_event(*args: Any, **kwargs: Any) -> AdHocExecutePromptEvent:
|
333
|
+
execution_id = str(uuid4())
|
334
|
+
return FulfilledAdHocExecutePromptEvent(
|
335
|
+
execution_id=execution_id,
|
336
|
+
outputs=expected_output,
|
337
|
+
)
|
338
|
+
|
339
|
+
vellum_adhoc_prompt_client.adhoc_execute_prompt.side_effect = generate_prompt_event
|
340
|
+
|
341
|
+
# WHEN the node is run
|
342
|
+
node = MyInlinePromptNode()
|
343
|
+
outputs = [o for o in node.run()]
|
344
|
+
|
345
|
+
# THEN the node should have produced the outputs we expect
|
346
|
+
result_output = outputs[0]
|
347
|
+
assert result_output.name == "results"
|
348
|
+
assert result_output.value == expected_output
|
349
|
+
|
350
|
+
# AND we should have made the expected call to Vellum search
|
351
|
+
vellum_adhoc_prompt_client.adhoc_execute_prompt.assert_called_once_with(
|
352
|
+
blocks=[],
|
353
|
+
expand_meta=Ellipsis,
|
354
|
+
functions=None,
|
355
|
+
input_values=[],
|
356
|
+
input_variables=[],
|
357
|
+
ml_model="gpt-4o",
|
358
|
+
parameters=PromptParameters(
|
359
|
+
stop=[],
|
360
|
+
temperature=0.0,
|
361
|
+
max_tokens=4096,
|
362
|
+
top_p=1.0,
|
363
|
+
top_k=0,
|
364
|
+
frequency_penalty=0.0,
|
365
|
+
presence_penalty=0.0,
|
366
|
+
logit_bias=None,
|
367
|
+
custom_parameters={},
|
368
|
+
),
|
369
|
+
request_options=mock.ANY,
|
370
|
+
settings=PromptSettings(stream_enabled=False),
|
371
|
+
)
|
372
|
+
|
373
|
+
|
374
|
+
def test_inline_prompt_node__json_output_with_streaming_disabled(vellum_adhoc_prompt_client):
|
375
|
+
# GIVEN an InlinePromptNode
|
376
|
+
class Inputs(BaseInputs):
|
377
|
+
input: str
|
378
|
+
|
379
|
+
class State(BaseState):
|
380
|
+
pass
|
381
|
+
|
382
|
+
class MyInlinePromptNode(InlinePromptNode):
|
383
|
+
ml_model = "gpt-4o"
|
384
|
+
blocks = []
|
385
|
+
parameters = PromptParameters(
|
386
|
+
stop=[],
|
387
|
+
temperature=0.0,
|
388
|
+
max_tokens=4096,
|
389
|
+
top_p=1.0,
|
390
|
+
top_k=0,
|
391
|
+
frequency_penalty=0.0,
|
392
|
+
presence_penalty=0.0,
|
393
|
+
logit_bias=None,
|
394
|
+
custom_parameters={
|
395
|
+
"json_mode": False,
|
396
|
+
"json_schema": {
|
397
|
+
"name": "get_result",
|
398
|
+
"schema": {
|
399
|
+
"type": "object",
|
400
|
+
"required": ["result"],
|
401
|
+
"properties": {"result": {"type": "string", "description": ""}},
|
402
|
+
},
|
403
|
+
},
|
404
|
+
},
|
405
|
+
)
|
406
|
+
settings = PromptSettings(stream_enabled=False)
|
407
|
+
|
408
|
+
# AND a known JSON response from invoking an inline prompt
|
409
|
+
expected_json = {"result": "Hello, world!"}
|
410
|
+
expected_outputs: List[PromptOutput] = [
|
411
|
+
StringVellumValue(value=json.dumps(expected_json)),
|
412
|
+
]
|
413
|
+
|
414
|
+
def generate_prompt_event(*args: Any, **kwargs: Any) -> AdHocExecutePromptEvent:
|
415
|
+
execution_id = str(uuid4())
|
416
|
+
return FulfilledAdHocExecutePromptEvent(
|
417
|
+
execution_id=execution_id,
|
418
|
+
outputs=expected_outputs,
|
419
|
+
)
|
420
|
+
|
421
|
+
vellum_adhoc_prompt_client.adhoc_execute_prompt.side_effect = generate_prompt_event
|
422
|
+
|
423
|
+
# WHEN the node is run
|
424
|
+
node = MyInlinePromptNode(
|
425
|
+
state=State(
|
426
|
+
meta=StateMeta(workflow_inputs=Inputs(input="Generate JSON.")),
|
427
|
+
)
|
428
|
+
)
|
429
|
+
outputs = [o for o in node.run()]
|
430
|
+
|
431
|
+
# THEN the node should have produced the outputs we expect
|
432
|
+
results_output = outputs[0]
|
433
|
+
assert results_output.name == "results"
|
434
|
+
assert results_output.value == expected_outputs
|
435
|
+
|
436
|
+
text_output = outputs[1]
|
437
|
+
assert text_output.name == "text"
|
438
|
+
assert text_output.value == '{"result": "Hello, world!"}'
|
439
|
+
|
440
|
+
json_output = outputs[2]
|
441
|
+
assert json_output.name == "json"
|
442
|
+
assert json_output.value == expected_json
|
443
|
+
|
444
|
+
# AND we should have made the expected call to Vellum search
|
445
|
+
vellum_adhoc_prompt_client.adhoc_execute_prompt.assert_called_once_with(
|
446
|
+
blocks=[],
|
447
|
+
expand_meta=Ellipsis,
|
448
|
+
functions=None,
|
449
|
+
input_values=[],
|
450
|
+
input_variables=[],
|
451
|
+
ml_model="gpt-4o",
|
452
|
+
parameters=PromptParameters(
|
453
|
+
stop=[],
|
454
|
+
temperature=0.0,
|
455
|
+
max_tokens=4096,
|
456
|
+
top_p=1.0,
|
457
|
+
top_k=0,
|
458
|
+
frequency_penalty=0.0,
|
459
|
+
presence_penalty=0.0,
|
460
|
+
logit_bias=None,
|
461
|
+
custom_parameters={
|
462
|
+
"json_mode": False,
|
463
|
+
"json_schema": {
|
464
|
+
"name": "get_result",
|
465
|
+
"schema": {
|
466
|
+
"type": "object",
|
467
|
+
"required": ["result"],
|
468
|
+
"properties": {"result": {"type": "string", "description": ""}},
|
469
|
+
},
|
470
|
+
},
|
471
|
+
},
|
472
|
+
),
|
473
|
+
request_options=mock.ANY,
|
474
|
+
settings=PromptSettings(stream_enabled=False),
|
475
|
+
)
|
@@ -8,11 +8,7 @@ from vellum.workflows.exceptions import NodeException
|
|
8
8
|
from vellum.workflows.graph.graph import Graph
|
9
9
|
from vellum.workflows.inputs.base import BaseInputs
|
10
10
|
from vellum.workflows.nodes.bases import BaseNode
|
11
|
-
from vellum.workflows.nodes.experimental.tool_calling_node.utils import
|
12
|
-
ToolRouterNode,
|
13
|
-
create_function_node,
|
14
|
-
create_tool_router_node,
|
15
|
-
)
|
11
|
+
from vellum.workflows.nodes.experimental.tool_calling_node.utils import create_function_node, create_tool_router_node
|
16
12
|
from vellum.workflows.outputs.base import BaseOutputs
|
17
13
|
from vellum.workflows.state.base import BaseState
|
18
14
|
from vellum.workflows.state.context import WorkflowContext
|
@@ -69,7 +65,7 @@ class ToolCallingNode(BaseNode):
|
|
69
65
|
graph = self._graph
|
70
66
|
|
71
67
|
class Outputs(BaseWorkflow.Outputs):
|
72
|
-
text: str =
|
68
|
+
text: str = self.tool_router_node.Outputs.text
|
73
69
|
chat_history: List[ChatMessage] = ToolCallingState.chat_history
|
74
70
|
|
75
71
|
subworkflow = ToolCallingWorkflow(
|
@@ -107,6 +103,7 @@ class ToolCallingNode(BaseNode):
|
|
107
103
|
self._function_nodes = {
|
108
104
|
function.__name__: create_function_node(
|
109
105
|
function=function,
|
106
|
+
tool_router_node=self.tool_router_node,
|
110
107
|
)
|
111
108
|
for function in self.functions
|
112
109
|
}
|