vellum-ai 0.14.38__py3-none-any.whl → 0.14.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. vellum/__init__.py +2 -0
  2. vellum/client/core/client_wrapper.py +1 -1
  3. vellum/client/types/__init__.py +2 -0
  4. vellum/client/types/test_suite_run_progress.py +20 -0
  5. vellum/client/types/test_suite_run_read.py +3 -0
  6. vellum/client/types/vellum_sdk_error_code_enum.py +1 -0
  7. vellum/client/types/workflow_execution_event_error_code.py +1 -0
  8. vellum/types/test_suite_run_progress.py +3 -0
  9. vellum/workflows/errors/types.py +1 -0
  10. vellum/workflows/events/tests/test_event.py +1 -0
  11. vellum/workflows/events/workflow.py +13 -3
  12. vellum/workflows/nodes/bases/tests/test_base_node.py +24 -0
  13. vellum/workflows/nodes/core/try_node/node.py +1 -2
  14. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +7 -1
  15. vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +6 -1
  16. vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +26 -0
  17. vellum/workflows/nodes/experimental/tool_calling_node/__init__.py +3 -0
  18. vellum/workflows/nodes/experimental/tool_calling_node/node.py +147 -0
  19. vellum/workflows/nodes/experimental/tool_calling_node/utils.py +132 -0
  20. vellum/workflows/nodes/utils.py +4 -2
  21. vellum/workflows/outputs/base.py +3 -2
  22. vellum/workflows/references/output.py +20 -0
  23. vellum/workflows/runner/runner.py +37 -17
  24. vellum/workflows/state/base.py +64 -19
  25. vellum/workflows/state/tests/test_state.py +31 -22
  26. vellum/workflows/types/stack.py +11 -0
  27. vellum/workflows/workflows/base.py +13 -18
  28. vellum/workflows/workflows/tests/test_base_workflow.py +96 -9
  29. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/METADATA +1 -1
  30. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/RECORD +82 -75
  31. vellum_cli/push.py +2 -5
  32. vellum_cli/tests/test_push.py +52 -0
  33. vellum_ee/workflows/display/base.py +14 -1
  34. vellum_ee/workflows/display/nodes/base_node_display.py +56 -14
  35. vellum_ee/workflows/display/nodes/get_node_display_class.py +9 -15
  36. vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +36 -0
  37. vellum_ee/workflows/display/nodes/vellum/base_adornment_node.py +3 -2
  38. vellum_ee/workflows/display/nodes/vellum/retry_node.py +1 -2
  39. vellum_ee/workflows/display/nodes/vellum/tests/test_code_execution_node.py +1 -2
  40. vellum_ee/workflows/display/nodes/vellum/tests/test_error_node.py +1 -2
  41. vellum_ee/workflows/display/nodes/vellum/tests/test_note_node.py +1 -2
  42. vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_node.py +55 -3
  43. vellum_ee/workflows/display/nodes/vellum/tests/test_retry_node.py +1 -2
  44. vellum_ee/workflows/display/nodes/vellum/tests/test_templating_node.py +1 -2
  45. vellum_ee/workflows/display/nodes/vellum/tests/test_try_node.py +1 -2
  46. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +2 -2
  47. vellum_ee/workflows/display/nodes/vellum/try_node.py +1 -2
  48. vellum_ee/workflows/display/nodes/vellum/utils.py +7 -1
  49. vellum_ee/workflows/display/tests/{test_vellum_workflow_display.py → test_base_workflow_display.py} +10 -22
  50. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +4 -6
  51. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +7 -16
  52. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +2 -6
  53. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +1 -2
  54. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +3 -10
  55. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +4 -5
  56. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +1 -4
  57. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +1 -4
  58. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +2 -5
  59. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +7 -5
  60. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_prompt_node_serialization.py +1 -4
  61. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -4
  62. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +1 -2
  63. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +1 -4
  64. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -4
  65. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +7 -5
  66. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +1 -4
  67. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -4
  68. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +1 -4
  69. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +2 -5
  70. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +2 -7
  71. vellum_ee/workflows/display/types.py +5 -4
  72. vellum_ee/workflows/display/utils/exceptions.py +7 -0
  73. vellum_ee/workflows/display/utils/registry.py +37 -0
  74. vellum_ee/workflows/display/utils/vellum.py +2 -1
  75. vellum_ee/workflows/display/workflows/base_workflow_display.py +281 -43
  76. vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +34 -21
  77. vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +58 -20
  78. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +4 -257
  79. vellum_ee/workflows/tests/local_workflow/display/workflow.py +2 -2
  80. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/LICENSE +0 -0
  81. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/WHEEL +0 -0
  82. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/entry_points.txt +0 -0
vellum/__init__.py CHANGED
@@ -443,6 +443,7 @@ from .types import (
443
443
  TestSuiteRunMetricNumberOutput,
444
444
  TestSuiteRunMetricOutput,
445
445
  TestSuiteRunMetricStringOutput,
446
+ TestSuiteRunProgress,
446
447
  TestSuiteRunPromptSandboxExecConfigDataRequest,
447
448
  TestSuiteRunPromptSandboxExecConfigRequest,
448
449
  TestSuiteRunPromptSandboxHistoryItemExecConfig,
@@ -1072,6 +1073,7 @@ __all__ = [
1072
1073
  "TestSuiteRunMetricNumberOutput",
1073
1074
  "TestSuiteRunMetricOutput",
1074
1075
  "TestSuiteRunMetricStringOutput",
1076
+ "TestSuiteRunProgress",
1075
1077
  "TestSuiteRunPromptSandboxExecConfigDataRequest",
1076
1078
  "TestSuiteRunPromptSandboxExecConfigRequest",
1077
1079
  "TestSuiteRunPromptSandboxHistoryItemExecConfig",
@@ -18,7 +18,7 @@ class BaseClientWrapper:
18
18
  headers: typing.Dict[str, str] = {
19
19
  "X-Fern-Language": "Python",
20
20
  "X-Fern-SDK-Name": "vellum-ai",
21
- "X-Fern-SDK-Version": "0.14.38",
21
+ "X-Fern-SDK-Version": "0.14.40",
22
22
  }
23
23
  headers["X_API_KEY"] = self.api_key
24
24
  return headers
@@ -453,6 +453,7 @@ from .test_suite_run_metric_json_output import TestSuiteRunMetricJsonOutput
453
453
  from .test_suite_run_metric_number_output import TestSuiteRunMetricNumberOutput
454
454
  from .test_suite_run_metric_output import TestSuiteRunMetricOutput
455
455
  from .test_suite_run_metric_string_output import TestSuiteRunMetricStringOutput
456
+ from .test_suite_run_progress import TestSuiteRunProgress
456
457
  from .test_suite_run_prompt_sandbox_exec_config_data_request import TestSuiteRunPromptSandboxExecConfigDataRequest
457
458
  from .test_suite_run_prompt_sandbox_exec_config_request import TestSuiteRunPromptSandboxExecConfigRequest
458
459
  from .test_suite_run_prompt_sandbox_history_item_exec_config import TestSuiteRunPromptSandboxHistoryItemExecConfig
@@ -1052,6 +1053,7 @@ __all__ = [
1052
1053
  "TestSuiteRunMetricNumberOutput",
1053
1054
  "TestSuiteRunMetricOutput",
1054
1055
  "TestSuiteRunMetricStringOutput",
1056
+ "TestSuiteRunProgress",
1055
1057
  "TestSuiteRunPromptSandboxExecConfigDataRequest",
1056
1058
  "TestSuiteRunPromptSandboxExecConfigRequest",
1057
1059
  "TestSuiteRunPromptSandboxHistoryItemExecConfig",
@@ -0,0 +1,20 @@
1
+ # This file was auto-generated by Fern from our API Definition.
2
+
3
+ from ..core.pydantic_utilities import UniversalBaseModel
4
+ from ..core.pydantic_utilities import IS_PYDANTIC_V2
5
+ import typing
6
+ import pydantic
7
+
8
+
9
+ class TestSuiteRunProgress(UniversalBaseModel):
10
+ number_of_requested_test_cases: int
11
+ number_of_completed_test_cases: int
12
+
13
+ if IS_PYDANTIC_V2:
14
+ model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
15
+ else:
16
+
17
+ class Config:
18
+ frozen = True
19
+ smart_union = True
20
+ extra = pydantic.Extra.allow
@@ -9,6 +9,7 @@ from .test_suite_run_state import TestSuiteRunState
9
9
  import pydantic
10
10
  import typing
11
11
  from .test_suite_run_exec_config import TestSuiteRunExecConfig
12
+ from .test_suite_run_progress import TestSuiteRunProgress
12
13
  from ..core.pydantic_utilities import IS_PYDANTIC_V2
13
14
  from ..core.pydantic_utilities import update_forward_refs
14
15
 
@@ -33,6 +34,8 @@ class TestSuiteRunRead(UniversalBaseModel):
33
34
  Configuration that defines how the Test Suite should be run
34
35
  """
35
36
 
37
+ progress: typing.Optional[TestSuiteRunProgress] = None
38
+
36
39
  if IS_PYDANTIC_V2:
37
40
  model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
38
41
  else:
@@ -11,6 +11,7 @@ VellumSdkErrorCodeEnum = typing.Union[
11
11
  "INVALID_CODE",
12
12
  "INVALID_TEMPLATE",
13
13
  "INTERNAL_ERROR",
14
+ "PROVIDER_CREDENTIALS_UNAVAILABLE",
14
15
  "PROVIDER_ERROR",
15
16
  "USER_DEFINED_ERROR",
16
17
  "WORKFLOW_CANCELLED",
@@ -6,6 +6,7 @@ WorkflowExecutionEventErrorCode = typing.Union[
6
6
  typing.Literal[
7
7
  "WORKFLOW_INITIALIZATION",
8
8
  "WORKFLOW_CANCELLED",
9
+ "PROVIDER_CREDENTIALS_UNAVAILABLE",
9
10
  "NODE_EXECUTION_COUNT_LIMIT_REACHED",
10
11
  "INTERNAL_SERVER_ERROR",
11
12
  "NODE_EXECUTION",
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.test_suite_run_progress import *
@@ -17,6 +17,7 @@ class WorkflowErrorCode(Enum):
17
17
  INVALID_TEMPLATE = "INVALID_TEMPLATE"
18
18
  INTERNAL_ERROR = "INTERNAL_ERROR"
19
19
  NODE_EXECUTION = "NODE_EXECUTION"
20
+ PROVIDER_CREDENTIALS_UNAVAILABLE = "PROVIDER_CREDENTIALS_UNAVAILABLE"
20
21
  PROVIDER_ERROR = "PROVIDER_ERROR"
21
22
  USER_DEFINED_ERROR = "USER_DEFINED_ERROR"
22
23
  WORKFLOW_CANCELLED = "WORKFLOW_CANCELLED"
@@ -89,6 +89,7 @@ mock_node_uuid = str(uuid4_from_hash(MockNode.__qualname__))
89
89
  "foo": "bar",
90
90
  },
91
91
  "display_context": None,
92
+ "initial_state": None,
92
93
  },
93
94
  "parent": None,
94
95
  },
@@ -54,8 +54,10 @@ class WorkflowEventDisplayContext(UniversalBaseModel):
54
54
  workflow_outputs: Dict[str, UUID]
55
55
 
56
56
 
57
- class WorkflowExecutionInitiatedBody(_BaseWorkflowExecutionBody, Generic[InputsType]):
57
+ class WorkflowExecutionInitiatedBody(_BaseWorkflowExecutionBody, Generic[InputsType, StateType]):
58
58
  inputs: InputsType
59
+ initial_state: Optional[StateType] = None
60
+
59
61
  # It is still the responsibility of the workflow server to populate this context. The SDK's
60
62
  # Workflow Runner will always leave this field None.
61
63
  #
@@ -67,15 +69,23 @@ class WorkflowExecutionInitiatedBody(_BaseWorkflowExecutionBody, Generic[InputsT
67
69
  def serialize_inputs(self, inputs: InputsType, _info: Any) -> Dict[str, Any]:
68
70
  return default_serializer(inputs)
69
71
 
72
+ @field_serializer("initial_state")
73
+ def serialize_initial_state(self, initial_state: Optional[StateType], _info: Any) -> Optional[Dict[str, Any]]:
74
+ return default_serializer(initial_state)
75
+
70
76
 
71
- class WorkflowExecutionInitiatedEvent(_BaseWorkflowEvent, Generic[InputsType]):
77
+ class WorkflowExecutionInitiatedEvent(_BaseWorkflowEvent, Generic[InputsType, StateType]):
72
78
  name: Literal["workflow.execution.initiated"] = "workflow.execution.initiated"
73
- body: WorkflowExecutionInitiatedBody[InputsType]
79
+ body: WorkflowExecutionInitiatedBody[InputsType, StateType]
74
80
 
75
81
  @property
76
82
  def inputs(self) -> InputsType:
77
83
  return self.body.inputs
78
84
 
85
+ @property
86
+ def initial_state(self) -> Optional[StateType]:
87
+ return self.body.initial_state
88
+
79
89
 
80
90
  class WorkflowExecutionStreamingBody(_BaseWorkflowExecutionBody):
81
91
  output: BaseOutput
@@ -4,11 +4,13 @@ from typing import Optional
4
4
 
5
5
  from vellum.client.types.string_vellum_value_request import StringVellumValueRequest
6
6
  from vellum.core.pydantic_utilities import UniversalBaseModel
7
+ from vellum.workflows.constants import undefined
7
8
  from vellum.workflows.descriptors.tests.test_utils import FixtureState
8
9
  from vellum.workflows.inputs.base import BaseInputs
9
10
  from vellum.workflows.nodes import FinalOutputNode
10
11
  from vellum.workflows.nodes.bases.base import BaseNode
11
12
  from vellum.workflows.outputs.base import BaseOutputs
13
+ from vellum.workflows.references.output import OutputReference
12
14
  from vellum.workflows.state.base import BaseState, StateMeta
13
15
 
14
16
 
@@ -259,3 +261,25 @@ def test_resolve_value__for_falsy_values(falsy_value, expected_type):
259
261
 
260
262
  # THEN the output has the correct value
261
263
  assert falsy_output.value == falsy_value
264
+
265
+
266
+ def test_node_outputs__inherits_instance():
267
+ # GIVEN a node with two outputs, one with and one without a default instance
268
+ class MyNode(BaseNode):
269
+ class Outputs:
270
+ foo: str
271
+ bar = "hello"
272
+
273
+ # AND a node that inherits from MyNode
274
+ class InheritedNode(MyNode):
275
+ pass
276
+
277
+ # WHEN we reference each output
278
+ foo_output = InheritedNode.Outputs.foo
279
+ bar_output = InheritedNode.Outputs.bar
280
+
281
+ # THEN the output reference instances are correct
282
+ assert isinstance(foo_output, OutputReference)
283
+ assert foo_output.instance is undefined
284
+ assert isinstance(bar_output, OutputReference)
285
+ assert bar_output.instance == "hello"
@@ -4,7 +4,6 @@ from vellum.workflows.context import execution_context, get_parent_context
4
4
  from vellum.workflows.errors.types import WorkflowError, WorkflowErrorCode
5
5
  from vellum.workflows.events.workflow import is_workflow_event
6
6
  from vellum.workflows.exceptions import NodeException
7
- from vellum.workflows.nodes.bases import BaseNode
8
7
  from vellum.workflows.nodes.bases.base_adornment_node import BaseAdornmentNode
9
8
  from vellum.workflows.nodes.utils import create_adornment
10
9
  from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
@@ -24,7 +23,7 @@ class TryNode(BaseAdornmentNode[StateType], Generic[StateType]):
24
23
 
25
24
  on_error_code: Optional[WorkflowErrorCode] = None
26
25
 
27
- class Outputs(BaseNode.Outputs):
26
+ class Outputs(BaseAdornmentNode.Outputs):
28
27
  error: Optional[WorkflowError] = None
29
28
 
30
29
  def run(self) -> Iterator[BaseOutput]:
@@ -69,7 +69,13 @@ class BasePromptNode(BaseNode, Generic[StateType]):
69
69
  return outputs
70
70
 
71
71
  def _handle_api_error(self, e: ApiError):
72
- if e.status_code and e.status_code >= 400 and e.status_code < 500 and isinstance(e.body, dict):
72
+ if e.status_code and e.status_code == 403 and isinstance(e.body, dict):
73
+ raise NodeException(
74
+ message=e.body.get("detail", "Provider credentials is missing or unavailable"),
75
+ code=WorkflowErrorCode.PROVIDER_CREDENTIALS_UNAVAILABLE,
76
+ )
77
+
78
+ elif e.status_code and e.status_code >= 400 and e.status_code < 500 and isinstance(e.body, dict):
73
79
  raise NodeException(
74
80
  message=e.body.get("detail", "Failed to execute Prompt"),
75
81
  code=WorkflowErrorCode.INVALID_INPUTS,
@@ -170,8 +170,13 @@ def test_inline_prompt_node__function_definitions(vellum_adhoc_prompt_client):
170
170
  WorkflowErrorCode.INTERNAL_ERROR,
171
171
  "Failed to execute Prompt",
172
172
  ),
173
+ (
174
+ ApiError(status_code=403, body={"detail": "Provider credentials is missing or unavailable"}),
175
+ WorkflowErrorCode.PROVIDER_CREDENTIALS_UNAVAILABLE,
176
+ "Provider credentials is missing or unavailable",
177
+ ),
173
178
  ],
174
- ids=["404", "invalid_dict", "invalid_body", "no_status_code", "500"],
179
+ ids=["404", "invalid_dict", "invalid_body", "no_status_code", "500", "403"],
175
180
  )
176
181
  def test_inline_prompt_node__api_error__invalid_inputs_node_exception(
177
182
  vellum_adhoc_prompt_client, exception, expected_code, expected_message
@@ -491,3 +491,29 @@ def test_prompt_deployment_node__no_fallbacks(vellum_client):
491
491
 
492
492
  # AND the client should have been called only once (for the primary model)
493
493
  assert vellum_client.execute_prompt_stream.call_count == 1
494
+
495
+
496
+ def test_prompt_deployment_node__provider_credentials_missing(vellum_client):
497
+ # GIVEN a Prompt Deployment Node
498
+ class TestPromptDeploymentNode(PromptDeploymentNode):
499
+ deployment = "test_deployment"
500
+ prompt_inputs = {}
501
+
502
+ # AND the client responds with a 403 error of provider credentials missing
503
+ primary_error = ApiError(
504
+ body={"detail": "Provider credentials is missing or unavailable"},
505
+ status_code=403,
506
+ )
507
+
508
+ vellum_client.execute_prompt_stream.side_effect = primary_error
509
+
510
+ # WHEN we run the node
511
+ node = TestPromptDeploymentNode()
512
+
513
+ # THEN the node should raise an exception
514
+ with pytest.raises(NodeException) as exc_info:
515
+ list(node.run())
516
+
517
+ # AND the exception should contain the original error message
518
+ assert exc_info.value.message == "Provider credentials is missing or unavailable"
519
+ assert exc_info.value.code == WorkflowErrorCode.PROVIDER_CREDENTIALS_UNAVAILABLE
@@ -0,0 +1,3 @@
1
+ from vellum.workflows.nodes.experimental.tool_calling_node.node import ToolCallingNode
2
+
3
+ __all__ = ["ToolCallingNode"]
@@ -0,0 +1,147 @@
1
+ from collections.abc import Callable
2
+ from typing import Any, ClassVar, Dict, List, Optional, cast
3
+
4
+ from vellum import ChatMessage, FunctionDefinition, PromptBlock
5
+ from vellum.client.types.chat_message_request import ChatMessageRequest
6
+ from vellum.workflows.context import execution_context, get_parent_context
7
+ from vellum.workflows.errors.types import WorkflowErrorCode
8
+ from vellum.workflows.exceptions import NodeException
9
+ from vellum.workflows.graph.graph import Graph
10
+ from vellum.workflows.inputs.base import BaseInputs
11
+ from vellum.workflows.nodes.bases import BaseNode
12
+ from vellum.workflows.nodes.experimental.tool_calling_node.utils import (
13
+ ToolRouterNode,
14
+ create_function_node,
15
+ create_tool_router_node,
16
+ )
17
+ from vellum.workflows.outputs.base import BaseOutputs
18
+ from vellum.workflows.state.base import BaseState
19
+ from vellum.workflows.state.context import WorkflowContext
20
+ from vellum.workflows.types.core import EntityInputsInterface
21
+ from vellum.workflows.workflows.base import BaseWorkflow
22
+
23
+
24
+ class ToolCallingNode(BaseNode):
25
+ """
26
+ A Node that dynamically invokes the provided functions to the underlying Prompt
27
+
28
+ Attributes:
29
+ ml_model: str - The model to use for tool calling (e.g., "gpt-4o-mini")
30
+ blocks: List[PromptBlock] - The prompt blocks to use (same format as InlinePromptNode)
31
+ functions: List[FunctionDefinition] - The functions that can be called
32
+ function_callables: List[Callable] - The callables that can be called
33
+ prompt_inputs: Optional[EntityInputsInterface] - Mapping of input variable names to values
34
+ """
35
+
36
+ ml_model: ClassVar[str] = "gpt-4o-mini"
37
+ blocks: ClassVar[List[PromptBlock]] = []
38
+ functions: ClassVar[List[FunctionDefinition]] = []
39
+ function_callables: ClassVar[Dict[str, Callable[..., Any]]] = {}
40
+ prompt_inputs: ClassVar[Optional[EntityInputsInterface]] = None
41
+ # TODO: https://linear.app/vellum/issue/APO-342/support-tool-call-max-retries
42
+ max_tool_calls: ClassVar[int] = 1
43
+
44
+ class Outputs(BaseOutputs):
45
+ """
46
+ The outputs of the ToolCallingNode.
47
+
48
+ text: The final text response after tool calling
49
+ chat_history: The complete chat history including tool calls
50
+ """
51
+
52
+ text: str = ""
53
+ chat_history: List[ChatMessage] = []
54
+
55
+ def run(self) -> Outputs:
56
+ """
57
+ Run the tool calling workflow.
58
+
59
+ This dynamically builds a graph with router and function nodes,
60
+ then executes the workflow.
61
+ """
62
+ self._validate_functions()
63
+
64
+ initial_chat_history = []
65
+
66
+ # Extract chat history from prompt inputs if available
67
+ if self.prompt_inputs and "chat_history" in self.prompt_inputs:
68
+ chat_history_input = self.prompt_inputs["chat_history"]
69
+ if isinstance(chat_history_input, list) and all(
70
+ isinstance(msg, (ChatMessage, ChatMessageRequest)) for msg in chat_history_input
71
+ ):
72
+ initial_chat_history = [
73
+ msg if isinstance(msg, ChatMessage) else ChatMessage.model_validate(msg.model_dump())
74
+ for msg in chat_history_input
75
+ ]
76
+
77
+ self._build_graph()
78
+
79
+ with execution_context(parent_context=get_parent_context()):
80
+
81
+ class ToolCallingState(BaseState):
82
+ chat_history: List[ChatMessage] = initial_chat_history
83
+
84
+ class ToolCallingWorkflow(BaseWorkflow[BaseInputs, ToolCallingState]):
85
+ graph = self._graph
86
+
87
+ class Outputs(BaseWorkflow.Outputs):
88
+ text: str = ToolRouterNode.Outputs.text
89
+ chat_history: List[ChatMessage] = ToolCallingState.chat_history
90
+
91
+ subworkflow = ToolCallingWorkflow(
92
+ parent_state=self.state,
93
+ context=WorkflowContext.create_from(self._context),
94
+ )
95
+
96
+ terminal_event = subworkflow.run()
97
+
98
+ if terminal_event.name == "workflow.execution.paused":
99
+ raise NodeException(
100
+ code=WorkflowErrorCode.INVALID_OUTPUTS,
101
+ message="Subworkflow unexpectedly paused",
102
+ )
103
+ elif terminal_event.name == "workflow.execution.fulfilled":
104
+ node_outputs = self.Outputs()
105
+
106
+ for output_descriptor, output_value in terminal_event.outputs:
107
+ setattr(node_outputs, output_descriptor.name, output_value)
108
+
109
+ return node_outputs
110
+ elif terminal_event.name == "workflow.execution.rejected":
111
+ raise Exception(f"Workflow execution rejected: {terminal_event.error}")
112
+
113
+ raise Exception(f"Unexpected workflow event: {terminal_event.name}")
114
+
115
+ def _build_graph(self) -> None:
116
+ self.tool_router_node = create_tool_router_node(
117
+ ml_model=self.ml_model,
118
+ blocks=self.blocks,
119
+ functions=self.functions,
120
+ prompt_inputs=self.prompt_inputs,
121
+ )
122
+
123
+ self._function_nodes = {
124
+ function.name: create_function_node(
125
+ function=function,
126
+ function_callable=cast(Callable[..., Any], self.function_callables[function.name]), # type: ignore
127
+ )
128
+ for function in self.functions
129
+ }
130
+
131
+ graph_set = set()
132
+
133
+ # Add connections from ports of router to function nodes and back to router
134
+ for function_name, FunctionNodeClass in self._function_nodes.items():
135
+ router_port = getattr(self.tool_router_node.Ports, function_name) # type: ignore # mypy thinks name is still optional # noqa: E501
136
+ edge_graph = router_port >> FunctionNodeClass >> self.tool_router_node
137
+ graph_set.add(edge_graph)
138
+
139
+ default_port = getattr(self.tool_router_node.Ports, "default")
140
+ graph_set.add(default_port)
141
+
142
+ self._graph = Graph.from_set(graph_set)
143
+
144
+ def _validate_functions(self) -> None:
145
+ for function in self.functions:
146
+ if function.name is None:
147
+ raise ValueError("Function name is required")
@@ -0,0 +1,132 @@
1
+ from collections.abc import Callable
2
+ import json
3
+ from typing import Any, Iterator, List, Optional, Type, cast
4
+
5
+ from vellum import ChatMessage, FunctionDefinition, PromptBlock
6
+ from vellum.client.types.function_call_chat_message_content import FunctionCallChatMessageContent
7
+ from vellum.client.types.function_call_chat_message_content_value import FunctionCallChatMessageContentValue
8
+ from vellum.client.types.variable_prompt_block import VariablePromptBlock
9
+ from vellum.workflows.nodes.bases import BaseNode
10
+ from vellum.workflows.nodes.displayable.inline_prompt_node.node import InlinePromptNode
11
+ from vellum.workflows.outputs.base import BaseOutput
12
+ from vellum.workflows.ports.port import Port
13
+ from vellum.workflows.references.lazy import LazyReference
14
+ from vellum.workflows.types.core import EntityInputsInterface
15
+
16
+
17
+ class FunctionNode(BaseNode):
18
+ """Node that executes a specific function."""
19
+
20
+ function: FunctionDefinition
21
+
22
+
23
+ class ToolRouterNode(InlinePromptNode):
24
+ def run(self) -> Iterator[BaseOutput]:
25
+ self.prompt_inputs = {**self.prompt_inputs, "chat_history": self.state.chat_history} # type: ignore
26
+ generator = super().run()
27
+ for output in generator:
28
+ if output.name == "results" and output.value:
29
+ values = cast(List[Any], output.value)
30
+ if values and len(values) > 0:
31
+ if values[0].type == "STRING":
32
+ self.state.chat_history.append(ChatMessage(role="ASSISTANT", text=values[0].value))
33
+ elif values[0].type == "FUNCTION_CALL":
34
+ function_call = values[0].value
35
+ if function_call is not None:
36
+ self.state.chat_history.append(
37
+ ChatMessage(
38
+ role="FUNCTION",
39
+ content=FunctionCallChatMessageContent(
40
+ value=FunctionCallChatMessageContentValue(
41
+ name=function_call.name,
42
+ arguments=function_call.arguments,
43
+ id=function_call.id,
44
+ ),
45
+ ),
46
+ )
47
+ )
48
+ yield output
49
+
50
+
51
+ def create_tool_router_node(
52
+ ml_model: str,
53
+ blocks: List[PromptBlock],
54
+ functions: List[FunctionDefinition],
55
+ prompt_inputs: Optional[EntityInputsInterface],
56
+ ) -> Type[ToolRouterNode]:
57
+ Ports = type("Ports", (), {})
58
+ for function in functions:
59
+ if function.name is None:
60
+ # We should not raise an error here since we filter out functions without names
61
+ raise ValueError("Function name is required")
62
+
63
+ function_name = function.name
64
+ port_condition = LazyReference(
65
+ lambda: (
66
+ ToolRouterNode.Outputs.results[0]["type"].equals("FUNCTION_CALL")
67
+ & ToolRouterNode.Outputs.results[0]["value"]["name"].equals(function_name)
68
+ )
69
+ )
70
+ port = Port.on_if(port_condition)
71
+ setattr(Ports, function_name, port)
72
+
73
+ setattr(Ports, "default", Port.on_else())
74
+
75
+ # Add a chat history block to blocks
76
+ blocks.append(
77
+ VariablePromptBlock(
78
+ block_type="VARIABLE",
79
+ input_variable="chat_history",
80
+ state=None,
81
+ cache_config=None,
82
+ )
83
+ )
84
+
85
+ node = type(
86
+ "ToolRouterNode",
87
+ (ToolRouterNode,),
88
+ {
89
+ "ml_model": ml_model,
90
+ "blocks": blocks,
91
+ "functions": functions,
92
+ "prompt_inputs": prompt_inputs,
93
+ "Ports": Ports,
94
+ "__module__": __name__,
95
+ },
96
+ )
97
+
98
+ return node
99
+
100
+
101
+ def create_function_node(function: FunctionDefinition, function_callable: Callable[..., Any]) -> Type[FunctionNode]:
102
+ """
103
+ Create a FunctionNode class for a given function.
104
+
105
+ This ensures the callable is properly registered and can be called with the expected arguments.
106
+ """
107
+
108
+ # Create a class-level wrapper that calls the original function
109
+ def execute_function(self) -> BaseNode.Outputs:
110
+ outputs = self.state.meta.node_outputs.get(ToolRouterNode.Outputs.text)
111
+ # first parse into json
112
+ outputs = json.loads(outputs)
113
+ arguments = outputs["arguments"]
114
+
115
+ # Call the original function directly with the arguments
116
+ result = function_callable(**arguments)
117
+
118
+ self.state.chat_history.append(ChatMessage(role="FUNCTION", text=result))
119
+
120
+ return self.Outputs()
121
+
122
+ node = type(
123
+ f"FunctionNode_{function.name}",
124
+ (FunctionNode,),
125
+ {
126
+ "function": function,
127
+ "run": execute_function,
128
+ "__module__": __name__,
129
+ },
130
+ )
131
+
132
+ return node
@@ -57,10 +57,12 @@ def create_adornment(
57
57
  class Subworkflow(BaseWorkflow):
58
58
  graph = inner_cls
59
59
 
60
- # mypy is wrong here, this works and is defined
61
- class Outputs(inner_cls.Outputs): # type: ignore[name-defined]
60
+ class Outputs(BaseWorkflow.Outputs):
62
61
  pass
63
62
 
63
+ for output_reference in inner_cls.Outputs:
64
+ setattr(Subworkflow.Outputs, output_reference.name, output_reference)
65
+
64
66
  dynamic_module = f"{inner_cls.__module__}.{inner_cls.__name__}.{ADORNMENT_MODULE_NAME}"
65
67
  # This dynamic module allows calls to `type_hints` to work
66
68
  sys.modules[dynamic_module] = ModuleType(dynamic_module)
@@ -147,8 +147,9 @@ class _BaseOutputsMeta(type):
147
147
  instance = vars(cls).get(name, undefined)
148
148
  if instance is undefined:
149
149
  for base in cls.__mro__[1:]:
150
- if hasattr(base, name):
151
- instance = getattr(base, name)
150
+ inherited_output_reference = getattr(base, name, undefined)
151
+ if isinstance(inherited_output_reference, OutputReference):
152
+ instance = inherited_output_reference.instance
152
153
  break
153
154
 
154
155
  types = infer_types(cls, name)
@@ -1,4 +1,6 @@
1
+ from functools import cached_property
1
2
  from queue import Queue
3
+ from uuid import UUID, uuid4
2
4
  from typing import TYPE_CHECKING, Any, Generator, Generic, Optional, Tuple, Type, TypeVar, cast
3
5
 
4
6
  from pydantic import GetCoreSchemaHandler
@@ -31,6 +33,24 @@ class OutputReference(BaseDescriptor[_OutputType], Generic[_OutputType]):
31
33
  def outputs_class(self) -> Type["BaseOutputs"]:
32
34
  return self._outputs_class
33
35
 
36
+ @cached_property
37
+ def id(self) -> UUID:
38
+ self._outputs_class = self._outputs_class
39
+
40
+ node_class = getattr(self._outputs_class, "_node_class", None)
41
+ if not node_class:
42
+ return uuid4()
43
+
44
+ output_ids = getattr(node_class, "__output_ids__", {})
45
+ if not isinstance(output_ids, dict):
46
+ return uuid4()
47
+
48
+ output_id = output_ids.get(self.name)
49
+ if not isinstance(output_id, UUID):
50
+ return uuid4()
51
+
52
+ return output_id
53
+
34
54
  def resolve(self, state: "BaseState") -> _OutputType:
35
55
  node_output = state.meta.node_outputs.get(self, undefined)
36
56
  if isinstance(node_output, Queue):