vellum-ai 0.14.40__py3-none-any.whl → 0.14.42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/__init__.py +2 -4
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/reference.md +141 -4
- vellum/client/resources/ad_hoc/client.py +311 -1
- vellum/client/resources/deployments/client.py +2 -2
- vellum/client/resources/documents/client.py +0 -6
- vellum/client/types/__init__.py +2 -4
- vellum/client/types/execute_api_response.py +3 -4
- vellum/client/types/execute_api_response_json.py +7 -0
- vellum/client/types/prompt_settings.py +1 -0
- vellum/client/types/workflow_event_execution_read.py +0 -4
- vellum/client/types/workflow_execution_initiated_body.py +0 -9
- vellum/client/types/workflow_execution_initiated_event.py +0 -4
- vellum/client/types/workflow_execution_span.py +0 -4
- vellum/types/{node_event_display_context.py → execute_api_response_json.py} +1 -1
- vellum/workflows/inputs/base.py +26 -3
- vellum/workflows/inputs/tests/test_inputs.py +15 -0
- vellum/workflows/nodes/bases/base_adornment_node.py +9 -0
- vellum/workflows/nodes/core/map_node/node.py +3 -2
- vellum/workflows/nodes/core/map_node/tests/test_node.py +56 -0
- vellum/workflows/nodes/core/retry_node/node.py +2 -1
- vellum/workflows/nodes/experimental/tool_calling_node/node.py +6 -28
- vellum/workflows/nodes/experimental/tool_calling_node/utils.py +6 -10
- vellum/workflows/nodes/utils.py +14 -1
- vellum/workflows/references/workflow_input.py +5 -1
- vellum/workflows/runner/runner.py +2 -0
- vellum/workflows/workflows/base.py +5 -0
- {vellum_ai-0.14.40.dist-info → vellum_ai-0.14.42.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.40.dist-info → vellum_ai-0.14.42.dist-info}/RECORD +65 -68
- vellum_ee/workflows/display/nodes/base_node_display.py +67 -28
- vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +18 -0
- vellum_ee/workflows/display/nodes/vellum/api_node.py +3 -2
- vellum_ee/workflows/display/nodes/vellum/base_adornment_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +3 -2
- vellum_ee/workflows/display/nodes/vellum/conditional_node.py +3 -2
- vellum_ee/workflows/display/nodes/vellum/error_node.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/final_output_node.py +8 -8
- vellum_ee/workflows/display/nodes/vellum/guardrail_node.py +3 -2
- vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +3 -2
- vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +3 -2
- vellum_ee/workflows/display/nodes/vellum/map_node.py +15 -12
- vellum_ee/workflows/display/nodes/vellum/merge_node.py +3 -2
- vellum_ee/workflows/display/nodes/vellum/note_node.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +3 -4
- vellum_ee/workflows/display/nodes/vellum/search_node.py +3 -2
- vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +3 -2
- vellum_ee/workflows/display/nodes/vellum/templating_node.py +3 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +2 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +3 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +138 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +3 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +2 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +2 -2
- vellum_ee/workflows/display/workflows/base_workflow_display.py +4 -12
- vellum/client/types/node_event_display_context.py +0 -30
- vellum/client/types/workflow_event_display_context.py +0 -28
- vellum/types/workflow_event_display_context.py +0 -3
- vellum_ee/workflows/display/nodes/base_node_vellum_display.py +0 -40
- {vellum_ai-0.14.40.dist-info → vellum_ai-0.14.42.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.40.dist-info → vellum_ai-0.14.42.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.40.dist-info → vellum_ai-0.14.42.dist-info}/entry_points.txt +0 -0
@@ -2,8 +2,9 @@
|
|
2
2
|
|
3
3
|
from ..core.pydantic_utilities import UniversalBaseModel
|
4
4
|
import typing_extensions
|
5
|
-
import
|
5
|
+
from .execute_api_response_json import ExecuteApiResponseJson
|
6
6
|
from ..core.serialization import FieldMetadata
|
7
|
+
import typing
|
7
8
|
from ..core.pydantic_utilities import IS_PYDANTIC_V2
|
8
9
|
import pydantic
|
9
10
|
|
@@ -11,9 +12,7 @@ import pydantic
|
|
11
12
|
class ExecuteApiResponse(UniversalBaseModel):
|
12
13
|
status_code: int
|
13
14
|
text: str
|
14
|
-
json_: typing_extensions.Annotated[
|
15
|
-
typing.Optional[typing.Dict[str, typing.Optional[typing.Any]]], FieldMetadata(alias="json")
|
16
|
-
] = None
|
15
|
+
json_: typing_extensions.Annotated[ExecuteApiResponseJson, FieldMetadata(alias="json")]
|
17
16
|
headers: typing.Dict[str, str]
|
18
17
|
|
19
18
|
if IS_PYDANTIC_V2:
|
@@ -8,6 +8,7 @@ import pydantic
|
|
8
8
|
|
9
9
|
class PromptSettings(UniversalBaseModel):
|
10
10
|
timeout: typing.Optional[float] = None
|
11
|
+
stream_enabled: typing.Optional[bool] = None
|
11
12
|
|
12
13
|
if IS_PYDANTIC_V2:
|
13
14
|
model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
|
@@ -10,8 +10,6 @@ from .workflow_deployment_parent_context import WorkflowDeploymentParentContext
|
|
10
10
|
from .workflow_parent_context import WorkflowParentContext
|
11
11
|
from .workflow_sandbox_parent_context import WorkflowSandboxParentContext
|
12
12
|
from .array_vellum_value import ArrayVellumValue
|
13
|
-
from .node_event_display_context import NodeEventDisplayContext
|
14
|
-
from .workflow_event_display_context import WorkflowEventDisplayContext
|
15
13
|
import typing
|
16
14
|
import datetime as dt
|
17
15
|
from .execution_vellum_value import ExecutionVellumValue
|
@@ -56,5 +54,3 @@ update_forward_refs(WorkflowDeploymentParentContext, WorkflowEventExecutionRead=
|
|
56
54
|
update_forward_refs(WorkflowParentContext, WorkflowEventExecutionRead=WorkflowEventExecutionRead)
|
57
55
|
update_forward_refs(WorkflowSandboxParentContext, WorkflowEventExecutionRead=WorkflowEventExecutionRead)
|
58
56
|
update_forward_refs(ArrayVellumValue, WorkflowEventExecutionRead=WorkflowEventExecutionRead)
|
59
|
-
update_forward_refs(NodeEventDisplayContext, WorkflowEventExecutionRead=WorkflowEventExecutionRead)
|
60
|
-
update_forward_refs(WorkflowEventDisplayContext, WorkflowEventExecutionRead=WorkflowEventExecutionRead)
|
@@ -1,19 +1,14 @@
|
|
1
1
|
# This file was auto-generated by Fern from our API Definition.
|
2
2
|
|
3
|
-
from __future__ import annotations
|
4
3
|
from ..core.pydantic_utilities import UniversalBaseModel
|
5
|
-
from .node_event_display_context import NodeEventDisplayContext
|
6
|
-
from .workflow_event_display_context import WorkflowEventDisplayContext
|
7
4
|
from .vellum_code_resource_definition import VellumCodeResourceDefinition
|
8
5
|
import typing
|
9
6
|
from ..core.pydantic_utilities import IS_PYDANTIC_V2
|
10
7
|
import pydantic
|
11
|
-
from ..core.pydantic_utilities import update_forward_refs
|
12
8
|
|
13
9
|
|
14
10
|
class WorkflowExecutionInitiatedBody(UniversalBaseModel):
|
15
11
|
workflow_definition: VellumCodeResourceDefinition
|
16
|
-
display_context: typing.Optional[WorkflowEventDisplayContext] = None
|
17
12
|
inputs: typing.Dict[str, typing.Optional[typing.Any]]
|
18
13
|
|
19
14
|
if IS_PYDANTIC_V2:
|
@@ -24,7 +19,3 @@ class WorkflowExecutionInitiatedBody(UniversalBaseModel):
|
|
24
19
|
frozen = True
|
25
20
|
smart_union = True
|
26
21
|
extra = pydantic.Extra.allow
|
27
|
-
|
28
|
-
|
29
|
-
update_forward_refs(NodeEventDisplayContext, WorkflowExecutionInitiatedBody=WorkflowExecutionInitiatedBody)
|
30
|
-
update_forward_refs(WorkflowEventDisplayContext, WorkflowExecutionInitiatedBody=WorkflowExecutionInitiatedBody)
|
@@ -9,8 +9,6 @@ from .span_link import SpanLink
|
|
9
9
|
from .workflow_deployment_parent_context import WorkflowDeploymentParentContext
|
10
10
|
from .workflow_parent_context import WorkflowParentContext
|
11
11
|
from .workflow_sandbox_parent_context import WorkflowSandboxParentContext
|
12
|
-
from .node_event_display_context import NodeEventDisplayContext
|
13
|
-
from .workflow_event_display_context import WorkflowEventDisplayContext
|
14
12
|
import typing
|
15
13
|
from .parent_context import ParentContext
|
16
14
|
from .workflow_execution_initiated_body import WorkflowExecutionInitiatedBody
|
@@ -49,5 +47,3 @@ update_forward_refs(SpanLink, WorkflowExecutionInitiatedEvent=WorkflowExecutionI
|
|
49
47
|
update_forward_refs(WorkflowDeploymentParentContext, WorkflowExecutionInitiatedEvent=WorkflowExecutionInitiatedEvent)
|
50
48
|
update_forward_refs(WorkflowParentContext, WorkflowExecutionInitiatedEvent=WorkflowExecutionInitiatedEvent)
|
51
49
|
update_forward_refs(WorkflowSandboxParentContext, WorkflowExecutionInitiatedEvent=WorkflowExecutionInitiatedEvent)
|
52
|
-
update_forward_refs(NodeEventDisplayContext, WorkflowExecutionInitiatedEvent=WorkflowExecutionInitiatedEvent)
|
53
|
-
update_forward_refs(WorkflowEventDisplayContext, WorkflowExecutionInitiatedEvent=WorkflowExecutionInitiatedEvent)
|
@@ -3,12 +3,10 @@
|
|
3
3
|
from __future__ import annotations
|
4
4
|
from ..core.pydantic_utilities import UniversalBaseModel
|
5
5
|
from .api_request_parent_context import ApiRequestParentContext
|
6
|
-
from .node_event_display_context import NodeEventDisplayContext
|
7
6
|
from .node_parent_context import NodeParentContext
|
8
7
|
from .prompt_deployment_parent_context import PromptDeploymentParentContext
|
9
8
|
from .span_link import SpanLink
|
10
9
|
from .workflow_deployment_parent_context import WorkflowDeploymentParentContext
|
11
|
-
from .workflow_event_display_context import WorkflowEventDisplayContext
|
12
10
|
from .workflow_parent_context import WorkflowParentContext
|
13
11
|
from .workflow_sandbox_parent_context import WorkflowSandboxParentContext
|
14
12
|
import typing
|
@@ -40,11 +38,9 @@ class WorkflowExecutionSpan(UniversalBaseModel):
|
|
40
38
|
|
41
39
|
|
42
40
|
update_forward_refs(ApiRequestParentContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
|
43
|
-
update_forward_refs(NodeEventDisplayContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
|
44
41
|
update_forward_refs(NodeParentContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
|
45
42
|
update_forward_refs(PromptDeploymentParentContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
|
46
43
|
update_forward_refs(SpanLink, WorkflowExecutionSpan=WorkflowExecutionSpan)
|
47
44
|
update_forward_refs(WorkflowDeploymentParentContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
|
48
|
-
update_forward_refs(WorkflowEventDisplayContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
|
49
45
|
update_forward_refs(WorkflowParentContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
|
50
46
|
update_forward_refs(WorkflowSandboxParentContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
|
vellum/workflows/inputs/base.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any, Iterator, Tuple, Type, Union, get_args, get_origin
|
1
|
+
from typing import Any, Dict, Iterator, Set, Tuple, Type, Union, get_args, get_origin
|
2
2
|
from typing_extensions import dataclass_transform
|
3
3
|
|
4
4
|
from pydantic import GetCoreSchemaHandler
|
@@ -14,11 +14,28 @@ from vellum.workflows.types.utils import get_class_attr_names, infer_types
|
|
14
14
|
|
15
15
|
@dataclass_transform(kw_only_default=True)
|
16
16
|
class _BaseInputsMeta(type):
|
17
|
+
def __new__(cls, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any]) -> Any:
|
18
|
+
dct["__parent_class__"] = type(None)
|
19
|
+
return super().__new__(cls, name, bases, dct)
|
20
|
+
|
17
21
|
def __getattribute__(cls, name: str) -> Any:
|
18
|
-
if
|
22
|
+
if name.startswith("_") or not issubclass(cls, BaseInputs):
|
23
|
+
return super().__getattribute__(name)
|
24
|
+
|
25
|
+
attr_names = get_class_attr_names(cls)
|
26
|
+
if name in attr_names:
|
27
|
+
# We first try to resolve the instance that this class attribute name is mapped to. If it's not found,
|
28
|
+
# we iterate through its inheritance hierarchy to find the first base class that has this attribute
|
29
|
+
# and use its mapping.
|
19
30
|
instance = vars(cls).get(name, undefined)
|
20
|
-
|
31
|
+
if instance is undefined:
|
32
|
+
for base in cls.__mro__[1:]:
|
33
|
+
inherited_input_reference = getattr(base, name, undefined)
|
34
|
+
if isinstance(inherited_input_reference, (ExternalInputReference, WorkflowInputReference)):
|
35
|
+
instance = inherited_input_reference.instance
|
36
|
+
break
|
21
37
|
|
38
|
+
types = infer_types(cls, name)
|
22
39
|
if getattr(cls, "__descriptor_class__", None) is ExternalInputReference:
|
23
40
|
return ExternalInputReference(name=name, types=types, instance=instance, inputs_class=cls)
|
24
41
|
else:
|
@@ -29,14 +46,20 @@ class _BaseInputsMeta(type):
|
|
29
46
|
def __iter__(cls) -> Iterator[InputReference]:
|
30
47
|
# We iterate through the inheritance hierarchy to find all the WorkflowInputReference attached to this
|
31
48
|
# Inputs class. __mro__ is the method resolution order, which is the order in which base classes are resolved.
|
49
|
+
yielded_attr_names: Set[str] = set()
|
50
|
+
|
32
51
|
for resolved_cls in cls.__mro__:
|
33
52
|
attr_names = get_class_attr_names(resolved_cls)
|
34
53
|
for attr_name in attr_names:
|
54
|
+
if attr_name in yielded_attr_names:
|
55
|
+
continue
|
56
|
+
|
35
57
|
attr_value = getattr(resolved_cls, attr_name)
|
36
58
|
if not isinstance(attr_value, (WorkflowInputReference, ExternalInputReference)):
|
37
59
|
continue
|
38
60
|
|
39
61
|
yield attr_value
|
62
|
+
yielded_attr_names.add(attr_name)
|
40
63
|
|
41
64
|
|
42
65
|
class BaseInputs(metaclass=_BaseInputsMeta):
|
@@ -47,3 +47,18 @@ def test_base_inputs_with_default():
|
|
47
47
|
|
48
48
|
# THEN it should use the default value
|
49
49
|
assert inputs.string_with_default == "default_value"
|
50
|
+
|
51
|
+
|
52
|
+
def test_base_inputs__supports_inherited_inputs():
|
53
|
+
# GIVEN an inputs class
|
54
|
+
class TopInputs(BaseInputs):
|
55
|
+
first: str
|
56
|
+
|
57
|
+
# WHEN we inherit from the base inputs class
|
58
|
+
class BottomInputs(TopInputs):
|
59
|
+
second: int
|
60
|
+
|
61
|
+
# THEN both references should be available
|
62
|
+
assert BottomInputs.first.name == "first"
|
63
|
+
assert BottomInputs.second.name == "second"
|
64
|
+
assert len([ref for ref in BottomInputs]) == 2
|
@@ -1,5 +1,6 @@
|
|
1
1
|
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Tuple, Type
|
2
2
|
|
3
|
+
from vellum.workflows.inputs.base import BaseInputs
|
3
4
|
from vellum.workflows.nodes.bases.base import BaseNode, BaseNodeMeta
|
4
5
|
from vellum.workflows.outputs.base import BaseOutputs
|
5
6
|
from vellum.workflows.references.output import OutputReference
|
@@ -13,6 +14,14 @@ class _BaseAdornmentNodeMeta(BaseNodeMeta):
|
|
13
14
|
def __new__(cls, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any]) -> Any:
|
14
15
|
node_class = super().__new__(cls, name, bases, dct)
|
15
16
|
|
17
|
+
SubworkflowInputs = dct.get("SubworkflowInputs")
|
18
|
+
if (
|
19
|
+
isinstance(SubworkflowInputs, type)
|
20
|
+
and issubclass(SubworkflowInputs, BaseInputs)
|
21
|
+
and SubworkflowInputs.__parent_class__ is type(None)
|
22
|
+
):
|
23
|
+
SubworkflowInputs.__parent_class__ = node_class
|
24
|
+
|
16
25
|
subworkflow_attribute = dct.get("subworkflow")
|
17
26
|
if not subworkflow_attribute:
|
18
27
|
return node_class
|
@@ -62,7 +62,7 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
|
|
62
62
|
|
63
63
|
item: MapNodeItemType # type: ignore[valid-type]
|
64
64
|
index: int
|
65
|
-
|
65
|
+
items: List[MapNodeItemType] # type: ignore[valid-type]
|
66
66
|
|
67
67
|
def run(self) -> Iterator[BaseOutput]:
|
68
68
|
mapped_items: Dict[str, List] = defaultdict(list)
|
@@ -176,8 +176,9 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
|
|
176
176
|
parent_state=self.state,
|
177
177
|
context=context,
|
178
178
|
)
|
179
|
+
SubworkflowInputsClass = self.subworkflow.get_inputs_class()
|
179
180
|
events = subworkflow.stream(
|
180
|
-
inputs=
|
181
|
+
inputs=SubworkflowInputsClass(index=index, item=item, items=self.items),
|
181
182
|
node_output_mocks=self._context._get_all_node_output_mocks(),
|
182
183
|
event_filter=all_workflow_event_filter,
|
183
184
|
)
|
@@ -116,3 +116,59 @@ def test_map_node__inner_try():
|
|
116
116
|
# THEN the workflow should succeed
|
117
117
|
assert outputs[-1].name == "final_output"
|
118
118
|
assert len(outputs[-1].value) == 2
|
119
|
+
|
120
|
+
|
121
|
+
def test_map_node__nested_map_node():
|
122
|
+
# GIVEN the inner map node's inputs
|
123
|
+
class VegetableMapNodeInputs(MapNode.SubworkflowInputs):
|
124
|
+
item: str
|
125
|
+
|
126
|
+
# AND the outer map node's inputs
|
127
|
+
class FruitMapNodeInputs(MapNode.SubworkflowInputs):
|
128
|
+
item: str
|
129
|
+
|
130
|
+
# AND a simple node that concats both attributes
|
131
|
+
class SimpleConcatNode(BaseNode):
|
132
|
+
fruit = FruitMapNodeInputs.item
|
133
|
+
vegetable = VegetableMapNodeInputs.item
|
134
|
+
|
135
|
+
class Outputs(BaseNode.Outputs):
|
136
|
+
medley: str
|
137
|
+
|
138
|
+
def run(self) -> Outputs:
|
139
|
+
return self.Outputs(medley=f"{self.fruit} {self.vegetable}")
|
140
|
+
|
141
|
+
# AND a workflow using that node
|
142
|
+
class VegetableMapNodeWorkflow(BaseWorkflow[VegetableMapNodeInputs, BaseState]):
|
143
|
+
graph = SimpleConcatNode
|
144
|
+
|
145
|
+
class Outputs(BaseWorkflow.Outputs):
|
146
|
+
final_output = SimpleConcatNode.Outputs.medley
|
147
|
+
|
148
|
+
# AND an inner map node referencing that workflow
|
149
|
+
class VegetableMapNode(MapNode):
|
150
|
+
items = ["carrot", "potato"]
|
151
|
+
subworkflow = VegetableMapNodeWorkflow
|
152
|
+
|
153
|
+
# AND an outer subworkflow referencing the inner map node
|
154
|
+
class FruitMapNodeWorkflow(BaseWorkflow[FruitMapNodeInputs, BaseState]):
|
155
|
+
graph = VegetableMapNode
|
156
|
+
|
157
|
+
class Outputs(BaseWorkflow.Outputs):
|
158
|
+
final_output = VegetableMapNode.Outputs.final_output
|
159
|
+
|
160
|
+
# AND an outer map node referencing the outer subworkflow
|
161
|
+
class FruitMapNode(MapNode):
|
162
|
+
items = ["apple", "banana"]
|
163
|
+
subworkflow = FruitMapNodeWorkflow
|
164
|
+
|
165
|
+
# WHEN we run the workflow
|
166
|
+
stream = FruitMapNode().run()
|
167
|
+
outputs = list(stream)
|
168
|
+
|
169
|
+
# THEN the workflow should succeed
|
170
|
+
assert outputs[-1].name == "final_output"
|
171
|
+
assert outputs[-1].value == [
|
172
|
+
["apple carrot", "apple potato"],
|
173
|
+
["banana carrot", "banana potato"],
|
174
|
+
]
|
@@ -47,8 +47,9 @@ class RetryNode(BaseAdornmentNode[StateType], Generic[StateType]):
|
|
47
47
|
parent_state=self.state,
|
48
48
|
context=WorkflowContext.create_from(self._context),
|
49
49
|
)
|
50
|
+
inputs_class = subworkflow.get_inputs_class()
|
50
51
|
subworkflow_stream = subworkflow.stream(
|
51
|
-
inputs=
|
52
|
+
inputs=inputs_class(attempt_number=attempt_number),
|
52
53
|
event_filter=all_workflow_event_filter,
|
53
54
|
node_output_mocks=self._context._get_all_node_output_mocks(),
|
54
55
|
)
|
@@ -1,8 +1,7 @@
|
|
1
1
|
from collections.abc import Callable
|
2
|
-
from typing import Any, ClassVar,
|
2
|
+
from typing import Any, ClassVar, List, Optional
|
3
3
|
|
4
|
-
from vellum import ChatMessage,
|
5
|
-
from vellum.client.types.chat_message_request import ChatMessageRequest
|
4
|
+
from vellum import ChatMessage, PromptBlock
|
6
5
|
from vellum.workflows.context import execution_context, get_parent_context
|
7
6
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
8
7
|
from vellum.workflows.exceptions import NodeException
|
@@ -35,8 +34,7 @@ class ToolCallingNode(BaseNode):
|
|
35
34
|
|
36
35
|
ml_model: ClassVar[str] = "gpt-4o-mini"
|
37
36
|
blocks: ClassVar[List[PromptBlock]] = []
|
38
|
-
functions: ClassVar[List[
|
39
|
-
function_callables: ClassVar[Dict[str, Callable[..., Any]]] = {}
|
37
|
+
functions: ClassVar[List[Callable[..., Any]]] = []
|
40
38
|
prompt_inputs: ClassVar[Optional[EntityInputsInterface]] = None
|
41
39
|
# TODO: https://linear.app/vellum/issue/APO-342/support-tool-call-max-retries
|
42
40
|
max_tool_calls: ClassVar[int] = 1
|
@@ -59,27 +57,13 @@ class ToolCallingNode(BaseNode):
|
|
59
57
|
This dynamically builds a graph with router and function nodes,
|
60
58
|
then executes the workflow.
|
61
59
|
"""
|
62
|
-
self._validate_functions()
|
63
|
-
|
64
|
-
initial_chat_history = []
|
65
|
-
|
66
|
-
# Extract chat history from prompt inputs if available
|
67
|
-
if self.prompt_inputs and "chat_history" in self.prompt_inputs:
|
68
|
-
chat_history_input = self.prompt_inputs["chat_history"]
|
69
|
-
if isinstance(chat_history_input, list) and all(
|
70
|
-
isinstance(msg, (ChatMessage, ChatMessageRequest)) for msg in chat_history_input
|
71
|
-
):
|
72
|
-
initial_chat_history = [
|
73
|
-
msg if isinstance(msg, ChatMessage) else ChatMessage.model_validate(msg.model_dump())
|
74
|
-
for msg in chat_history_input
|
75
|
-
]
|
76
60
|
|
77
61
|
self._build_graph()
|
78
62
|
|
79
63
|
with execution_context(parent_context=get_parent_context()):
|
80
64
|
|
81
65
|
class ToolCallingState(BaseState):
|
82
|
-
chat_history: List[ChatMessage] =
|
66
|
+
chat_history: List[ChatMessage] = []
|
83
67
|
|
84
68
|
class ToolCallingWorkflow(BaseWorkflow[BaseInputs, ToolCallingState]):
|
85
69
|
graph = self._graph
|
@@ -121,9 +105,8 @@ class ToolCallingNode(BaseNode):
|
|
121
105
|
)
|
122
106
|
|
123
107
|
self._function_nodes = {
|
124
|
-
function.
|
108
|
+
function.__name__: create_function_node(
|
125
109
|
function=function,
|
126
|
-
function_callable=cast(Callable[..., Any], self.function_callables[function.name]), # type: ignore
|
127
110
|
)
|
128
111
|
for function in self.functions
|
129
112
|
}
|
@@ -132,7 +115,7 @@ class ToolCallingNode(BaseNode):
|
|
132
115
|
|
133
116
|
# Add connections from ports of router to function nodes and back to router
|
134
117
|
for function_name, FunctionNodeClass in self._function_nodes.items():
|
135
|
-
router_port = getattr(self.tool_router_node.Ports, function_name)
|
118
|
+
router_port = getattr(self.tool_router_node.Ports, function_name)
|
136
119
|
edge_graph = router_port >> FunctionNodeClass >> self.tool_router_node
|
137
120
|
graph_set.add(edge_graph)
|
138
121
|
|
@@ -140,8 +123,3 @@ class ToolCallingNode(BaseNode):
|
|
140
123
|
graph_set.add(default_port)
|
141
124
|
|
142
125
|
self._graph = Graph.from_set(graph_set)
|
143
|
-
|
144
|
-
def _validate_functions(self) -> None:
|
145
|
-
for function in self.functions:
|
146
|
-
if function.name is None:
|
147
|
-
raise ValueError("Function name is required")
|
@@ -35,7 +35,7 @@ class ToolRouterNode(InlinePromptNode):
|
|
35
35
|
if function_call is not None:
|
36
36
|
self.state.chat_history.append(
|
37
37
|
ChatMessage(
|
38
|
-
role="
|
38
|
+
role="ASSISTANT",
|
39
39
|
content=FunctionCallChatMessageContent(
|
40
40
|
value=FunctionCallChatMessageContentValue(
|
41
41
|
name=function_call.name,
|
@@ -51,16 +51,12 @@ class ToolRouterNode(InlinePromptNode):
|
|
51
51
|
def create_tool_router_node(
|
52
52
|
ml_model: str,
|
53
53
|
blocks: List[PromptBlock],
|
54
|
-
functions: List[
|
54
|
+
functions: List[Callable[..., Any]],
|
55
55
|
prompt_inputs: Optional[EntityInputsInterface],
|
56
56
|
) -> Type[ToolRouterNode]:
|
57
57
|
Ports = type("Ports", (), {})
|
58
58
|
for function in functions:
|
59
|
-
|
60
|
-
# We should not raise an error here since we filter out functions without names
|
61
|
-
raise ValueError("Function name is required")
|
62
|
-
|
63
|
-
function_name = function.name
|
59
|
+
function_name = function.__name__
|
64
60
|
port_condition = LazyReference(
|
65
61
|
lambda: (
|
66
62
|
ToolRouterNode.Outputs.results[0]["type"].equals("FUNCTION_CALL")
|
@@ -98,7 +94,7 @@ def create_tool_router_node(
|
|
98
94
|
return node
|
99
95
|
|
100
96
|
|
101
|
-
def create_function_node(function:
|
97
|
+
def create_function_node(function: Callable[..., Any]) -> Type[FunctionNode]:
|
102
98
|
"""
|
103
99
|
Create a FunctionNode class for a given function.
|
104
100
|
|
@@ -113,14 +109,14 @@ def create_function_node(function: FunctionDefinition, function_callable: Callab
|
|
113
109
|
arguments = outputs["arguments"]
|
114
110
|
|
115
111
|
# Call the original function directly with the arguments
|
116
|
-
result =
|
112
|
+
result = function(**arguments)
|
117
113
|
|
118
114
|
self.state.chat_history.append(ChatMessage(role="FUNCTION", text=result))
|
119
115
|
|
120
116
|
return self.Outputs()
|
121
117
|
|
122
118
|
node = type(
|
123
|
-
f"FunctionNode_{function.
|
119
|
+
f"FunctionNode_{function.__name__}",
|
124
120
|
(FunctionNode,),
|
125
121
|
{
|
126
122
|
"function": function,
|
vellum/workflows/nodes/utils.py
CHANGED
@@ -9,9 +9,11 @@ from pydantic import BaseModel, create_model
|
|
9
9
|
from vellum.client.types.function_call import FunctionCall
|
10
10
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
11
11
|
from vellum.workflows.exceptions import NodeException
|
12
|
+
from vellum.workflows.inputs.base import BaseInputs
|
12
13
|
from vellum.workflows.nodes import BaseNode
|
13
14
|
from vellum.workflows.nodes.bases.base_adornment_node import BaseAdornmentNode
|
14
15
|
from vellum.workflows.ports.port import Port
|
16
|
+
from vellum.workflows.state.base import BaseState
|
15
17
|
from vellum.workflows.types.core import Json
|
16
18
|
from vellum.workflows.types.generics import NodeType
|
17
19
|
|
@@ -54,7 +56,18 @@ def create_adornment(
|
|
54
56
|
# https://app.shortcut.com/vellum/story/4116
|
55
57
|
from vellum.workflows import BaseWorkflow
|
56
58
|
|
57
|
-
|
59
|
+
SubworkflowInputs = getattr(adornable_cls, "SubworkflowInputs", None)
|
60
|
+
BaseSubworkflowInputs = (
|
61
|
+
SubworkflowInputs
|
62
|
+
if isinstance(SubworkflowInputs, type) and issubclass(SubworkflowInputs, BaseInputs)
|
63
|
+
else BaseInputs
|
64
|
+
)
|
65
|
+
|
66
|
+
# mypy is too conservative here - you can absolutely inherit from dynamic classes in python
|
67
|
+
class Inputs(BaseSubworkflowInputs): # type: ignore[misc, valid-type]
|
68
|
+
pass
|
69
|
+
|
70
|
+
class Subworkflow(BaseWorkflow[Inputs, BaseState]):
|
58
71
|
graph = inner_cls
|
59
72
|
|
60
73
|
class Outputs(BaseWorkflow.Outputs):
|
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Generic, Optional, Tuple, Type, TypeVar, cast
|
|
3
3
|
from vellum.workflows.descriptors.base import BaseDescriptor
|
4
4
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
5
5
|
from vellum.workflows.exceptions import NodeException
|
6
|
+
from vellum.workflows.types.generics import import_workflow_class
|
6
7
|
|
7
8
|
if TYPE_CHECKING:
|
8
9
|
from vellum.workflows.inputs.base import BaseInputs
|
@@ -29,7 +30,10 @@ class WorkflowInputReference(BaseDescriptor[_InputType], Generic[_InputType]):
|
|
29
30
|
return self._inputs_class
|
30
31
|
|
31
32
|
def resolve(self, state: "BaseState") -> _InputType:
|
32
|
-
if hasattr(state.meta.workflow_inputs, self._name)
|
33
|
+
if hasattr(state.meta.workflow_inputs, self._name) and (
|
34
|
+
state.meta.workflow_definition == self._inputs_class.__parent_class__
|
35
|
+
or not issubclass(self._inputs_class.__parent_class__, import_workflow_class())
|
36
|
+
):
|
33
37
|
return cast(_InputType, getattr(state.meta.workflow_inputs, self._name))
|
34
38
|
|
35
39
|
if state.meta.parent:
|
@@ -101,6 +101,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
101
101
|
if state:
|
102
102
|
self._initial_state = deepcopy(state)
|
103
103
|
self._initial_state.meta.span_id = uuid4()
|
104
|
+
self._initial_state.meta.workflow_definition = self.workflow.__class__
|
104
105
|
else:
|
105
106
|
self._initial_state = self.workflow.get_state_at_node(node)
|
106
107
|
self._entrypoints = entrypoint_nodes
|
@@ -126,6 +127,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
126
127
|
self._initial_state = deepcopy(state)
|
127
128
|
self._initial_state.meta.workflow_inputs = normalized_inputs
|
128
129
|
self._initial_state.meta.span_id = uuid4()
|
130
|
+
self._initial_state.meta.workflow_definition = self.workflow.__class__
|
129
131
|
else:
|
130
132
|
self._initial_state = self.workflow.get_default_state(normalized_inputs)
|
131
133
|
# We don't want to emit the initial state on the base case of Workflow Runs, since
|
@@ -133,6 +133,11 @@ class _BaseWorkflowMeta(type):
|
|
133
133
|
cls = super().__new__(mcs, name, bases, dct)
|
134
134
|
workflow_class = cast(Type["BaseWorkflow"], cls)
|
135
135
|
workflow_class.__id__ = uuid4_from_hash(workflow_class.__qualname__)
|
136
|
+
|
137
|
+
inputs_class = workflow_class.get_inputs_class()
|
138
|
+
if inputs_class is not BaseInputs and inputs_class.__parent_class__ is type(None):
|
139
|
+
inputs_class.__parent_class__ = workflow_class
|
140
|
+
|
136
141
|
return workflow_class
|
137
142
|
|
138
143
|
|