vellum-ai 0.14.38__py3-none-any.whl → 0.14.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/__init__.py +2 -0
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/types/__init__.py +2 -0
- vellum/client/types/test_suite_run_progress.py +20 -0
- vellum/client/types/test_suite_run_read.py +3 -0
- vellum/client/types/vellum_sdk_error_code_enum.py +1 -0
- vellum/client/types/workflow_execution_event_error_code.py +1 -0
- vellum/types/test_suite_run_progress.py +3 -0
- vellum/workflows/errors/types.py +1 -0
- vellum/workflows/events/tests/test_event.py +1 -0
- vellum/workflows/events/workflow.py +13 -3
- vellum/workflows/nodes/bases/tests/test_base_node.py +24 -0
- vellum/workflows/nodes/core/try_node/node.py +1 -2
- vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +7 -1
- vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +6 -1
- vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +26 -0
- vellum/workflows/nodes/experimental/tool_calling_node/__init__.py +3 -0
- vellum/workflows/nodes/experimental/tool_calling_node/node.py +147 -0
- vellum/workflows/nodes/experimental/tool_calling_node/utils.py +132 -0
- vellum/workflows/nodes/utils.py +4 -2
- vellum/workflows/outputs/base.py +3 -2
- vellum/workflows/references/output.py +20 -0
- vellum/workflows/runner/runner.py +37 -17
- vellum/workflows/state/base.py +64 -19
- vellum/workflows/state/tests/test_state.py +31 -22
- vellum/workflows/types/stack.py +11 -0
- vellum/workflows/workflows/base.py +13 -18
- vellum/workflows/workflows/tests/test_base_workflow.py +96 -9
- {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/RECORD +82 -75
- vellum_cli/push.py +2 -5
- vellum_cli/tests/test_push.py +52 -0
- vellum_ee/workflows/display/base.py +14 -1
- vellum_ee/workflows/display/nodes/base_node_display.py +56 -14
- vellum_ee/workflows/display/nodes/get_node_display_class.py +9 -15
- vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +36 -0
- vellum_ee/workflows/display/nodes/vellum/base_adornment_node.py +3 -2
- vellum_ee/workflows/display/nodes/vellum/retry_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_code_execution_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_error_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_note_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_node.py +55 -3
- vellum_ee/workflows/display/nodes/vellum/tests/test_retry_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_templating_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_try_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/try_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/utils.py +7 -1
- vellum_ee/workflows/display/tests/{test_vellum_workflow_display.py → test_base_workflow_display.py} +10 -22
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +4 -6
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +7 -16
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +2 -6
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +1 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +3 -10
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +4 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +2 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +7 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_prompt_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +1 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +7 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +2 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +2 -7
- vellum_ee/workflows/display/types.py +5 -4
- vellum_ee/workflows/display/utils/exceptions.py +7 -0
- vellum_ee/workflows/display/utils/registry.py +37 -0
- vellum_ee/workflows/display/utils/vellum.py +2 -1
- vellum_ee/workflows/display/workflows/base_workflow_display.py +281 -43
- vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +34 -21
- vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +58 -20
- vellum_ee/workflows/display/workflows/vellum_workflow_display.py +4 -257
- vellum_ee/workflows/tests/local_workflow/display/workflow.py +2 -2
- {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/entry_points.txt +0 -0
vellum/__init__.py
CHANGED
@@ -443,6 +443,7 @@ from .types import (
     TestSuiteRunMetricNumberOutput,
     TestSuiteRunMetricOutput,
     TestSuiteRunMetricStringOutput,
+    TestSuiteRunProgress,
     TestSuiteRunPromptSandboxExecConfigDataRequest,
     TestSuiteRunPromptSandboxExecConfigRequest,
     TestSuiteRunPromptSandboxHistoryItemExecConfig,
@@ -1072,6 +1073,7 @@ __all__ = [
     "TestSuiteRunMetricNumberOutput",
     "TestSuiteRunMetricOutput",
     "TestSuiteRunMetricStringOutput",
+    "TestSuiteRunProgress",
     "TestSuiteRunPromptSandboxExecConfigDataRequest",
     "TestSuiteRunPromptSandboxExecConfigRequest",
     "TestSuiteRunPromptSandboxHistoryItemExecConfig",
vellum/client/core/client_wrapper.py
CHANGED
@@ -18,7 +18,7 @@ class BaseClientWrapper:
         headers: typing.Dict[str, str] = {
             "X-Fern-Language": "Python",
             "X-Fern-SDK-Name": "vellum-ai",
-            "X-Fern-SDK-Version": "0.14.38",
+            "X-Fern-SDK-Version": "0.14.40",
         }
         headers["X_API_KEY"] = self.api_key
         return headers
vellum/client/types/__init__.py
CHANGED
@@ -453,6 +453,7 @@ from .test_suite_run_metric_json_output import TestSuiteRunMetricJsonOutput
 from .test_suite_run_metric_number_output import TestSuiteRunMetricNumberOutput
 from .test_suite_run_metric_output import TestSuiteRunMetricOutput
 from .test_suite_run_metric_string_output import TestSuiteRunMetricStringOutput
+from .test_suite_run_progress import TestSuiteRunProgress
 from .test_suite_run_prompt_sandbox_exec_config_data_request import TestSuiteRunPromptSandboxExecConfigDataRequest
 from .test_suite_run_prompt_sandbox_exec_config_request import TestSuiteRunPromptSandboxExecConfigRequest
 from .test_suite_run_prompt_sandbox_history_item_exec_config import TestSuiteRunPromptSandboxHistoryItemExecConfig
@@ -1052,6 +1053,7 @@ __all__ = [
     "TestSuiteRunMetricNumberOutput",
     "TestSuiteRunMetricOutput",
     "TestSuiteRunMetricStringOutput",
+    "TestSuiteRunProgress",
     "TestSuiteRunPromptSandboxExecConfigDataRequest",
     "TestSuiteRunPromptSandboxExecConfigRequest",
     "TestSuiteRunPromptSandboxHistoryItemExecConfig",
vellum/client/types/test_suite_run_progress.py
ADDED
@@ -0,0 +1,20 @@
+# This file was auto-generated by Fern from our API Definition.
+
+from ..core.pydantic_utilities import UniversalBaseModel
+from ..core.pydantic_utilities import IS_PYDANTIC_V2
+import typing
+import pydantic
+
+
+class TestSuiteRunProgress(UniversalBaseModel):
+    number_of_requested_test_cases: int
+    number_of_completed_test_cases: int
+
+    if IS_PYDANTIC_V2:
+        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
+    else:
+
+        class Config:
+            frozen = True
+            smart_union = True
+            extra = pydantic.Extra.allow
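The new model is a plain frozen Pydantic type, so it can be constructed and read directly. A minimal sketch, not part of the diff (the counts are made-up values):

from vellum.client.types.test_suite_run_progress import TestSuiteRunProgress

progress = TestSuiteRunProgress(
    number_of_requested_test_cases=10,
    number_of_completed_test_cases=4,
)

# The model is frozen, so treat it as a read-only snapshot.
percent_done = 100 * progress.number_of_completed_test_cases / progress.number_of_requested_test_cases
print(f"{percent_done:.0f}% complete")  # 40% complete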
vellum/client/types/test_suite_run_read.py
CHANGED
@@ -9,6 +9,7 @@ from .test_suite_run_state import TestSuiteRunState
 import pydantic
 import typing
 from .test_suite_run_exec_config import TestSuiteRunExecConfig
+from .test_suite_run_progress import TestSuiteRunProgress
 from ..core.pydantic_utilities import IS_PYDANTIC_V2
 from ..core.pydantic_utilities import update_forward_refs
 
@@ -33,6 +34,8 @@ class TestSuiteRunRead(UniversalBaseModel):
     Configuration that defines how the Test Suite should be run
     """
 
+    progress: typing.Optional[TestSuiteRunProgress] = None
+
     if IS_PYDANTIC_V2:
         model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
     else:
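Because `progress` is optional on `TestSuiteRunRead`, callers should guard for `None`. A polling sketch, assuming the Fern-generated `test_suite_runs.retrieve` client method and placeholder credentials/ids:

import time
from vellum import Vellum

client = Vellum(api_key="...")  # placeholder credentials

run = client.test_suite_runs.retrieve(id="...")  # placeholder run id
while run.progress is None or (
    run.progress.number_of_completed_test_cases < run.progress.number_of_requested_test_cases
):
    # Sketch only: a real loop would also check run.state and bail out on failure.
    time.sleep(5)
    run = client.test_suite_runs.retrieve(id="...")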
vellum/workflows/errors/types.py
CHANGED
@@ -17,6 +17,7 @@ class WorkflowErrorCode(Enum):
     INVALID_TEMPLATE = "INVALID_TEMPLATE"
     INTERNAL_ERROR = "INTERNAL_ERROR"
     NODE_EXECUTION = "NODE_EXECUTION"
+    PROVIDER_CREDENTIALS_UNAVAILABLE = "PROVIDER_CREDENTIALS_UNAVAILABLE"
     PROVIDER_ERROR = "PROVIDER_ERROR"
     USER_DEFINED_ERROR = "USER_DEFINED_ERROR"
     WORKFLOW_CANCELLED = "WORKFLOW_CANCELLED"
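Workflow code can branch on the new code when handling node failures. A sketch; the try/except and the `prompt_node` instance are hypothetical:

from vellum.workflows.errors.types import WorkflowErrorCode
from vellum.workflows.exceptions import NodeException

try:
    outputs = list(prompt_node.run())  # any prompt node instance
except NodeException as exc:
    if exc.code == WorkflowErrorCode.PROVIDER_CREDENTIALS_UNAVAILABLE:
        # A provider 403 now maps here (see the BasePromptNode change below)
        print(f"Configure provider credentials first: {exc.message}")
    else:
        raise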
vellum/workflows/events/workflow.py
CHANGED
@@ -54,8 +54,10 @@ class WorkflowEventDisplayContext(UniversalBaseModel):
     workflow_outputs: Dict[str, UUID]
 
 
-class WorkflowExecutionInitiatedBody(_BaseWorkflowExecutionBody, Generic[InputsType]):
+class WorkflowExecutionInitiatedBody(_BaseWorkflowExecutionBody, Generic[InputsType, StateType]):
     inputs: InputsType
+    initial_state: Optional[StateType] = None
+
     # It is still the responsibility of the workflow server to populate this context. The SDK's
     # Workflow Runner will always leave this field None.
     #
@@ -67,15 +69,23 @@ class WorkflowExecutionInitiatedBody(_BaseWorkflowExecutionBody, Generic[InputsT
     def serialize_inputs(self, inputs: InputsType, _info: Any) -> Dict[str, Any]:
         return default_serializer(inputs)
 
+    @field_serializer("initial_state")
+    def serialize_initial_state(self, initial_state: Optional[StateType], _info: Any) -> Optional[Dict[str, Any]]:
+        return default_serializer(initial_state)
+
 
-class WorkflowExecutionInitiatedEvent(_BaseWorkflowEvent, Generic[InputsType]):
+class WorkflowExecutionInitiatedEvent(_BaseWorkflowEvent, Generic[InputsType, StateType]):
     name: Literal["workflow.execution.initiated"] = "workflow.execution.initiated"
-    body: WorkflowExecutionInitiatedBody[InputsType]
+    body: WorkflowExecutionInitiatedBody[InputsType, StateType]
 
     @property
     def inputs(self) -> InputsType:
         return self.body.inputs
 
+    @property
+    def initial_state(self) -> Optional[StateType]:
+        return self.body.initial_state
+
 
 class WorkflowExecutionStreamingBody(_BaseWorkflowExecutionBody):
     output: BaseOutput
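Consumers of the event stream can now read the state a run was initiated with. A sketch; `MyWorkflow` is a hypothetical `BaseWorkflow` subclass, and `initial_state` is `None` unless an initial state was supplied:

workflow = MyWorkflow()
for event in workflow.stream():
    if event.name == "workflow.execution.initiated":
        print(event.inputs)         # proxied from event.body.inputs
        print(event.initial_state)  # proxied from event.body.initial_state; may be None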
vellum/workflows/nodes/bases/tests/test_base_node.py
CHANGED
@@ -4,11 +4,13 @@ from typing import Optional
 
 from vellum.client.types.string_vellum_value_request import StringVellumValueRequest
 from vellum.core.pydantic_utilities import UniversalBaseModel
+from vellum.workflows.constants import undefined
 from vellum.workflows.descriptors.tests.test_utils import FixtureState
 from vellum.workflows.inputs.base import BaseInputs
 from vellum.workflows.nodes import FinalOutputNode
 from vellum.workflows.nodes.bases.base import BaseNode
 from vellum.workflows.outputs.base import BaseOutputs
+from vellum.workflows.references.output import OutputReference
 from vellum.workflows.state.base import BaseState, StateMeta
 
 
@@ -259,3 +261,25 @@ def test_resolve_value__for_falsy_values(falsy_value, expected_type):
 
     # THEN the output has the correct value
     assert falsy_output.value == falsy_value
+
+
+def test_node_outputs__inherits_instance():
+    # GIVEN a node with two outputs, one with and one without a default instance
+    class MyNode(BaseNode):
+        class Outputs:
+            foo: str
+            bar = "hello"
+
+    # AND a node that inherits from MyNode
+    class InheritedNode(MyNode):
+        pass
+
+    # WHEN we reference each output
+    foo_output = InheritedNode.Outputs.foo
+    bar_output = InheritedNode.Outputs.bar
+
+    # THEN the output reference instances are correct
+    assert isinstance(foo_output, OutputReference)
+    assert foo_output.instance is undefined
+    assert isinstance(bar_output, OutputReference)
+    assert bar_output.instance == "hello"
vellum/workflows/nodes/core/try_node/node.py
CHANGED
@@ -4,7 +4,6 @@ from vellum.workflows.context import execution_context, get_parent_context
 from vellum.workflows.errors.types import WorkflowError, WorkflowErrorCode
 from vellum.workflows.events.workflow import is_workflow_event
 from vellum.workflows.exceptions import NodeException
-from vellum.workflows.nodes.bases import BaseNode
 from vellum.workflows.nodes.bases.base_adornment_node import BaseAdornmentNode
 from vellum.workflows.nodes.utils import create_adornment
 from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
@@ -24,7 +23,7 @@ class TryNode(BaseAdornmentNode[StateType], Generic[StateType]):
 
     on_error_code: Optional[WorkflowErrorCode] = None
 
-    class Outputs(BaseNode.Outputs):
+    class Outputs(BaseAdornmentNode.Outputs):
         error: Optional[WorkflowError] = None
 
     def run(self) -> Iterator[BaseOutput]:
vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py
CHANGED
@@ -69,7 +69,13 @@ class BasePromptNode(BaseNode, Generic[StateType]):
         return outputs
 
     def _handle_api_error(self, e: ApiError):
-        if e.status_code and e.status_code >= 400 and e.status_code < 500 and isinstance(e.body, dict):
+        if e.status_code and e.status_code == 403 and isinstance(e.body, dict):
+            raise NodeException(
+                message=e.body.get("detail", "Provider credentials is missing or unavailable"),
+                code=WorkflowErrorCode.PROVIDER_CREDENTIALS_UNAVAILABLE,
+            )
+
+        elif e.status_code and e.status_code >= 400 and e.status_code < 500 and isinstance(e.body, dict):
             raise NodeException(
                 message=e.body.get("detail", "Failed to execute Prompt"),
                 code=WorkflowErrorCode.INVALID_INPUTS,
vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py
CHANGED
@@ -170,8 +170,13 @@ def test_inline_prompt_node__function_definitions(vellum_adhoc_prompt_client):
             WorkflowErrorCode.INTERNAL_ERROR,
             "Failed to execute Prompt",
         ),
+        (
+            ApiError(status_code=403, body={"detail": "Provider credentials is missing or unavailable"}),
+            WorkflowErrorCode.PROVIDER_CREDENTIALS_UNAVAILABLE,
+            "Provider credentials is missing or unavailable",
+        ),
     ],
-    ids=["404", "invalid_dict", "invalid_body", "no_status_code", "500"],
+    ids=["404", "invalid_dict", "invalid_body", "no_status_code", "500", "403"],
 )
 def test_inline_prompt_node__api_error__invalid_inputs_node_exception(
     vellum_adhoc_prompt_client, exception, expected_code, expected_message
vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py
CHANGED
@@ -491,3 +491,29 @@ def test_prompt_deployment_node__no_fallbacks(vellum_client):
 
     # AND the client should have been called only once (for the primary model)
     assert vellum_client.execute_prompt_stream.call_count == 1
+
+
+def test_prompt_deployment_node__provider_credentials_missing(vellum_client):
+    # GIVEN a Prompt Deployment Node
+    class TestPromptDeploymentNode(PromptDeploymentNode):
+        deployment = "test_deployment"
+        prompt_inputs = {}
+
+    # AND the client responds with a 403 error of provider credentials missing
+    primary_error = ApiError(
+        body={"detail": "Provider credentials is missing or unavailable"},
+        status_code=403,
+    )
+
+    vellum_client.execute_prompt_stream.side_effect = primary_error
+
+    # WHEN we run the node
+    node = TestPromptDeploymentNode()
+
+    # THEN the node should raise an exception
+    with pytest.raises(NodeException) as exc_info:
+        list(node.run())
+
+    # AND the exception should contain the original error message
+    assert exc_info.value.message == "Provider credentials is missing or unavailable"
+    assert exc_info.value.code == WorkflowErrorCode.PROVIDER_CREDENTIALS_UNAVAILABLE
vellum/workflows/nodes/experimental/tool_calling_node/node.py
ADDED
@@ -0,0 +1,147 @@
+from collections.abc import Callable
+from typing import Any, ClassVar, Dict, List, Optional, cast
+
+from vellum import ChatMessage, FunctionDefinition, PromptBlock
+from vellum.client.types.chat_message_request import ChatMessageRequest
+from vellum.workflows.context import execution_context, get_parent_context
+from vellum.workflows.errors.types import WorkflowErrorCode
+from vellum.workflows.exceptions import NodeException
+from vellum.workflows.graph.graph import Graph
+from vellum.workflows.inputs.base import BaseInputs
+from vellum.workflows.nodes.bases import BaseNode
+from vellum.workflows.nodes.experimental.tool_calling_node.utils import (
+    ToolRouterNode,
+    create_function_node,
+    create_tool_router_node,
+)
+from vellum.workflows.outputs.base import BaseOutputs
+from vellum.workflows.state.base import BaseState
+from vellum.workflows.state.context import WorkflowContext
+from vellum.workflows.types.core import EntityInputsInterface
+from vellum.workflows.workflows.base import BaseWorkflow
+
+
+class ToolCallingNode(BaseNode):
+    """
+    A Node that dynamically invokes the provided functions to the underlying Prompt
+
+    Attributes:
+        ml_model: str - The model to use for tool calling (e.g., "gpt-4o-mini")
+        blocks: List[PromptBlock] - The prompt blocks to use (same format as InlinePromptNode)
+        functions: List[FunctionDefinition] - The functions that can be called
+        function_callables: List[Callable] - The callables that can be called
+        prompt_inputs: Optional[EntityInputsInterface] - Mapping of input variable names to values
+    """
+
+    ml_model: ClassVar[str] = "gpt-4o-mini"
+    blocks: ClassVar[List[PromptBlock]] = []
+    functions: ClassVar[List[FunctionDefinition]] = []
+    function_callables: ClassVar[Dict[str, Callable[..., Any]]] = {}
+    prompt_inputs: ClassVar[Optional[EntityInputsInterface]] = None
+    # TODO: https://linear.app/vellum/issue/APO-342/support-tool-call-max-retries
+    max_tool_calls: ClassVar[int] = 1
+
+    class Outputs(BaseOutputs):
+        """
+        The outputs of the ToolCallingNode.
+
+        text: The final text response after tool calling
+        chat_history: The complete chat history including tool calls
+        """
+
+        text: str = ""
+        chat_history: List[ChatMessage] = []
+
+    def run(self) -> Outputs:
+        """
+        Run the tool calling workflow.
+
+        This dynamically builds a graph with router and function nodes,
+        then executes the workflow.
+        """
+        self._validate_functions()
+
+        initial_chat_history = []
+
+        # Extract chat history from prompt inputs if available
+        if self.prompt_inputs and "chat_history" in self.prompt_inputs:
+            chat_history_input = self.prompt_inputs["chat_history"]
+            if isinstance(chat_history_input, list) and all(
+                isinstance(msg, (ChatMessage, ChatMessageRequest)) for msg in chat_history_input
+            ):
+                initial_chat_history = [
+                    msg if isinstance(msg, ChatMessage) else ChatMessage.model_validate(msg.model_dump())
+                    for msg in chat_history_input
+                ]
+
+        self._build_graph()
+
+        with execution_context(parent_context=get_parent_context()):
+
+            class ToolCallingState(BaseState):
+                chat_history: List[ChatMessage] = initial_chat_history
+
+            class ToolCallingWorkflow(BaseWorkflow[BaseInputs, ToolCallingState]):
+                graph = self._graph
+
+                class Outputs(BaseWorkflow.Outputs):
+                    text: str = ToolRouterNode.Outputs.text
+                    chat_history: List[ChatMessage] = ToolCallingState.chat_history
+
+            subworkflow = ToolCallingWorkflow(
+                parent_state=self.state,
+                context=WorkflowContext.create_from(self._context),
+            )
+
+            terminal_event = subworkflow.run()
+
+            if terminal_event.name == "workflow.execution.paused":
+                raise NodeException(
+                    code=WorkflowErrorCode.INVALID_OUTPUTS,
+                    message="Subworkflow unexpectedly paused",
+                )
+            elif terminal_event.name == "workflow.execution.fulfilled":
+                node_outputs = self.Outputs()
+
+                for output_descriptor, output_value in terminal_event.outputs:
+                    setattr(node_outputs, output_descriptor.name, output_value)
+
+                return node_outputs
+            elif terminal_event.name == "workflow.execution.rejected":
+                raise Exception(f"Workflow execution rejected: {terminal_event.error}")
+
+            raise Exception(f"Unexpected workflow event: {terminal_event.name}")
+
+    def _build_graph(self) -> None:
+        self.tool_router_node = create_tool_router_node(
+            ml_model=self.ml_model,
+            blocks=self.blocks,
+            functions=self.functions,
+            prompt_inputs=self.prompt_inputs,
+        )
+
+        self._function_nodes = {
+            function.name: create_function_node(
+                function=function,
+                function_callable=cast(Callable[..., Any], self.function_callables[function.name]),  # type: ignore
+            )
+            for function in self.functions
+        }
+
+        graph_set = set()
+
+        # Add connections from ports of router to function nodes and back to router
+        for function_name, FunctionNodeClass in self._function_nodes.items():
+            router_port = getattr(self.tool_router_node.Ports, function_name)  # type: ignore # mypy thinks name is still optional  # noqa: E501
+            edge_graph = router_port >> FunctionNodeClass >> self.tool_router_node
+            graph_set.add(edge_graph)
+
+        default_port = getattr(self.tool_router_node.Ports, "default")
+        graph_set.add(default_port)
+
+        self._graph = Graph.from_set(graph_set)
+
+    def _validate_functions(self) -> None:
+        for function in self.functions:
+            if function.name is None:
+                raise ValueError("Function name is required")
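A usage sketch for the experimental node. The weather function is made up, and this assumes the package `__init__.py` (whose +3 diff is not shown above) re-exports `ToolCallingNode`:

from vellum import FunctionDefinition
from vellum.workflows.nodes.experimental.tool_calling_node import ToolCallingNode


def get_current_weather(location: str) -> str:
    return f"It is sunny in {location}."  # stand-in implementation


class WeatherToolNode(ToolCallingNode):
    ml_model = "gpt-4o-mini"
    blocks = []  # same PromptBlock list an InlinePromptNode would take
    functions = [
        FunctionDefinition(
            name="get_current_weather",
            parameters={
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        )
    ]
    # Keyed by FunctionDefinition.name, matching _build_graph's lookup
    function_callables = {"get_current_weather": get_current_weather}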
vellum/workflows/nodes/experimental/tool_calling_node/utils.py
ADDED
@@ -0,0 +1,132 @@
+from collections.abc import Callable
+import json
+from typing import Any, Iterator, List, Optional, Type, cast
+
+from vellum import ChatMessage, FunctionDefinition, PromptBlock
+from vellum.client.types.function_call_chat_message_content import FunctionCallChatMessageContent
+from vellum.client.types.function_call_chat_message_content_value import FunctionCallChatMessageContentValue
+from vellum.client.types.variable_prompt_block import VariablePromptBlock
+from vellum.workflows.nodes.bases import BaseNode
+from vellum.workflows.nodes.displayable.inline_prompt_node.node import InlinePromptNode
+from vellum.workflows.outputs.base import BaseOutput
+from vellum.workflows.ports.port import Port
+from vellum.workflows.references.lazy import LazyReference
+from vellum.workflows.types.core import EntityInputsInterface
+
+
+class FunctionNode(BaseNode):
+    """Node that executes a specific function."""
+
+    function: FunctionDefinition
+
+
+class ToolRouterNode(InlinePromptNode):
+    def run(self) -> Iterator[BaseOutput]:
+        self.prompt_inputs = {**self.prompt_inputs, "chat_history": self.state.chat_history}  # type: ignore
+        generator = super().run()
+        for output in generator:
+            if output.name == "results" and output.value:
+                values = cast(List[Any], output.value)
+                if values and len(values) > 0:
+                    if values[0].type == "STRING":
+                        self.state.chat_history.append(ChatMessage(role="ASSISTANT", text=values[0].value))
+                    elif values[0].type == "FUNCTION_CALL":
+                        function_call = values[0].value
+                        if function_call is not None:
+                            self.state.chat_history.append(
+                                ChatMessage(
+                                    role="FUNCTION",
+                                    content=FunctionCallChatMessageContent(
+                                        value=FunctionCallChatMessageContentValue(
+                                            name=function_call.name,
+                                            arguments=function_call.arguments,
+                                            id=function_call.id,
+                                        ),
+                                    ),
+                                )
+                            )
+            yield output
+
+
+def create_tool_router_node(
+    ml_model: str,
+    blocks: List[PromptBlock],
+    functions: List[FunctionDefinition],
+    prompt_inputs: Optional[EntityInputsInterface],
+) -> Type[ToolRouterNode]:
+    Ports = type("Ports", (), {})
+    for function in functions:
+        if function.name is None:
+            # We should not raise an error here since we filter out functions without names
+            raise ValueError("Function name is required")
+
+        function_name = function.name
+        port_condition = LazyReference(
+            lambda: (
+                ToolRouterNode.Outputs.results[0]["type"].equals("FUNCTION_CALL")
+                & ToolRouterNode.Outputs.results[0]["value"]["name"].equals(function_name)
+            )
+        )
+        port = Port.on_if(port_condition)
+        setattr(Ports, function_name, port)
+
+    setattr(Ports, "default", Port.on_else())
+
+    # Add a chat history block to blocks
+    blocks.append(
+        VariablePromptBlock(
+            block_type="VARIABLE",
+            input_variable="chat_history",
+            state=None,
+            cache_config=None,
+        )
+    )
+
+    node = type(
+        "ToolRouterNode",
+        (ToolRouterNode,),
+        {
+            "ml_model": ml_model,
+            "blocks": blocks,
+            "functions": functions,
+            "prompt_inputs": prompt_inputs,
+            "Ports": Ports,
+            "__module__": __name__,
+        },
+    )
+
+    return node
+
+
+def create_function_node(function: FunctionDefinition, function_callable: Callable[..., Any]) -> Type[FunctionNode]:
+    """
+    Create a FunctionNode class for a given function.
+
+    This ensures the callable is properly registered and can be called with the expected arguments.
+    """
+
+    # Create a class-level wrapper that calls the original function
+    def execute_function(self) -> BaseNode.Outputs:
+        outputs = self.state.meta.node_outputs.get(ToolRouterNode.Outputs.text)
+        # first parse into json
+        outputs = json.loads(outputs)
+        arguments = outputs["arguments"]
+
+        # Call the original function directly with the arguments
+        result = function_callable(**arguments)
+
+        self.state.chat_history.append(ChatMessage(role="FUNCTION", text=result))
+
+        return self.Outputs()
+
+    node = type(
+        f"FunctionNode_{function.name}",
+        (FunctionNode,),
+        {
+            "function": function,
+            "run": execute_function,
+            "__module__": __name__,
+        },
+    )
+
+    return node
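These two factories are what `ToolCallingNode._build_graph` calls under the hood: `create_tool_router_node` stamps out an `InlinePromptNode` subclass with one `Port` per function plus a default, and `create_function_node` wraps a plain callable in a `FunctionNode` subclass. A direct-use sketch with a made-up `add` function:

from vellum import FunctionDefinition
from vellum.workflows.nodes.experimental.tool_calling_node.utils import (
    create_function_node,
    create_tool_router_node,
)


def add(a: int, b: int) -> str:
    return str(a + b)


add_fn = FunctionDefinition(
    name="add",
    parameters={"type": "object", "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}},
)
router = create_tool_router_node(ml_model="gpt-4o-mini", blocks=[], functions=[add_fn], prompt_inputs=None)
add_node = create_function_node(function=add_fn, function_callable=add)

# Route "add" tool calls through the function node and loop back to the router,
# mirroring what ToolCallingNode._build_graph assembles.
edge = router.Ports.add >> add_node >> router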
vellum/workflows/nodes/utils.py
CHANGED
@@ -57,10 +57,12 @@ def create_adornment(
     class Subworkflow(BaseWorkflow):
         graph = inner_cls
 
-
-        class Outputs(inner_cls.Outputs):  # type: ignore[name-defined]
+        class Outputs(BaseWorkflow.Outputs):
             pass
 
+    for output_reference in inner_cls.Outputs:
+        setattr(Subworkflow.Outputs, output_reference.name, output_reference)
+
     dynamic_module = f"{inner_cls.__module__}.{inner_cls.__name__}.{ADORNMENT_MODULE_NAME}"
     # This dynamic module allows calls to `type_hints` to work
     sys.modules[dynamic_module] = ModuleType(dynamic_module)
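The behavior change: the adornment subworkflow's `Outputs` no longer subclasses the inner node's `Outputs` directly; it subclasses `BaseWorkflow.Outputs` and copies each inner output reference over, so adorned nodes keep resolvable references to the inner outputs. A sketch, assuming TryNode's `wrap()` decorator entry point (which routes through `create_adornment`):

from vellum.workflows.nodes.bases import BaseNode
from vellum.workflows.nodes.core.try_node.node import TryNode


@TryNode.wrap()
class GuardedNode(BaseNode):
    class Outputs(BaseNode.Outputs):
        value: str

    def run(self) -> BaseNode.Outputs:
        return self.Outputs(value="ok")

# The adorned node still exposes the inner `value` output alongside TryNode's `error`.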
vellum/workflows/outputs/base.py
CHANGED
@@ -147,8 +147,9 @@ class _BaseOutputsMeta(type):
         instance = vars(cls).get(name, undefined)
         if instance is undefined:
             for base in cls.__mro__[1:]:
-
-
+                inherited_output_reference = getattr(base, name, undefined)
+                if isinstance(inherited_output_reference, OutputReference):
+                    instance = inherited_output_reference.instance
                 break
 
         types = infer_types(cls, name)
vellum/workflows/references/output.py
CHANGED
@@ -1,4 +1,6 @@
+from functools import cached_property
 from queue import Queue
+from uuid import UUID, uuid4
 from typing import TYPE_CHECKING, Any, Generator, Generic, Optional, Tuple, Type, TypeVar, cast
 
 from pydantic import GetCoreSchemaHandler
@@ -31,6 +33,24 @@ class OutputReference(BaseDescriptor[_OutputType], Generic[_OutputType]):
     def outputs_class(self) -> Type["BaseOutputs"]:
         return self._outputs_class
 
+    @cached_property
+    def id(self) -> UUID:
+        self._outputs_class = self._outputs_class
+
+        node_class = getattr(self._outputs_class, "_node_class", None)
+        if not node_class:
+            return uuid4()
+
+        output_ids = getattr(node_class, "__output_ids__", {})
+        if not isinstance(output_ids, dict):
+            return uuid4()
+
+        output_id = output_ids.get(self.name)
+        if not isinstance(output_id, UUID):
+            return uuid4()
+
+        return output_id
+
     def resolve(self, state: "BaseState") -> _OutputType:
         node_output = state.meta.node_outputs.get(self, undefined)
         if isinstance(node_output, Queue):