vellum-ai 1.4.0__py3-none-any.whl → 1.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. vellum/client/core/client_wrapper.py +2 -2
  2. vellum/workflows/constants.py +4 -0
  3. vellum/workflows/emitters/base.py +8 -0
  4. vellum/workflows/emitters/vellum_emitter.py +10 -0
  5. vellum/workflows/events/exception_handling.py +58 -0
  6. vellum/workflows/events/tests/test_event.py +27 -0
  7. vellum/workflows/exceptions.py +11 -6
  8. vellum/workflows/inputs/base.py +1 -0
  9. vellum/workflows/inputs/dataset_row.py +2 -2
  10. vellum/workflows/nodes/bases/base.py +12 -1
  11. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +6 -0
  12. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +16 -2
  13. vellum/workflows/nodes/displayable/final_output_node/node.py +59 -0
  14. vellum/workflows/nodes/displayable/final_output_node/tests/test_node.py +40 -1
  15. vellum/workflows/nodes/displayable/tool_calling_node/node.py +3 -0
  16. vellum/workflows/nodes/displayable/tool_calling_node/tests/test_utils.py +64 -0
  17. vellum/workflows/nodes/displayable/tool_calling_node/utils.py +30 -41
  18. vellum/workflows/nodes/mocks.py +15 -4
  19. vellum/workflows/tests/test_dataset_row.py +29 -0
  20. vellum/workflows/types/core.py +13 -2
  21. vellum/workflows/types/definition.py +13 -1
  22. vellum/workflows/utils/functions.py +63 -26
  23. vellum/workflows/utils/tests/test_functions.py +10 -6
  24. vellum/workflows/vellum_client.py +7 -1
  25. vellum/workflows/workflows/base.py +8 -0
  26. {vellum_ai-1.4.0.dist-info → vellum_ai-1.4.2.dist-info}/METADATA +1 -1
  27. {vellum_ai-1.4.0.dist-info → vellum_ai-1.4.2.dist-info}/RECORD +38 -36
  28. vellum_cli/tests/test_pull.py +1 -0
  29. vellum_cli/tests/test_push.py +2 -0
  30. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +1 -3
  31. vellum_ee/workflows/display/nodes/vellum/tests/test_final_output_node.py +78 -0
  32. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_inline_workflow_serialization.py +5 -0
  33. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_serialization.py +5 -0
  34. vellum_ee/workflows/display/types.py +3 -0
  35. vellum_ee/workflows/display/workflows/base_workflow_display.py +6 -0
  36. {vellum_ai-1.4.0.dist-info → vellum_ai-1.4.2.dist-info}/LICENSE +0 -0
  37. {vellum_ai-1.4.0.dist-info → vellum_ai-1.4.2.dist-info}/WHEEL +0 -0
  38. {vellum_ai-1.4.0.dist-info → vellum_ai-1.4.2.dist-info}/entry_points.txt +0 -0
@@ -27,10 +27,10 @@ class BaseClientWrapper:
27
27
 
28
28
  def get_headers(self) -> typing.Dict[str, str]:
29
29
  headers: typing.Dict[str, str] = {
30
- "User-Agent": "vellum-ai/1.4.0",
30
+ "User-Agent": "vellum-ai/1.4.2",
31
31
  "X-Fern-Language": "Python",
32
32
  "X-Fern-SDK-Name": "vellum-ai",
33
- "X-Fern-SDK-Version": "1.4.0",
33
+ "X-Fern-SDK-Version": "1.4.2",
34
34
  **(self.get_custom_headers() or {}),
35
35
  }
36
36
  if self._api_version is not None:
@@ -58,3 +58,7 @@ class APIRequestMethod(Enum):
58
58
  class AuthorizationType(Enum):
59
59
  BEARER_TOKEN = "BEARER_TOKEN"
60
60
  API_KEY = "API_KEY"
61
+
62
+
63
+ class VellumIntegrationProviderType(Enum):
64
+ COMPOSIO = "COMPOSIO"
@@ -29,3 +29,11 @@ class BaseWorkflowEmitter(ABC):
29
29
  @abstractmethod
30
30
  def snapshot_state(self, state: BaseState) -> None:
31
31
  pass
32
+
33
+ @abstractmethod
34
+ def join(self) -> None:
35
+ """
36
+ Wait for any background threads or timers used by this emitter to complete.
37
+ This ensures all pending work is finished before the workflow terminates.
38
+ """
39
+ pass
@@ -135,3 +135,13 @@ class VellumEmitter(BaseWorkflowEmitter):
135
135
  request=events, # type: ignore[arg-type]
136
136
  request_options=request_options,
137
137
  )
138
+
139
+ def join(self) -> None:
140
+ """
141
+ Wait for any background threads or timers used by this emitter to complete.
142
+ This ensures all pending work is finished before the workflow terminates.
143
+ """
144
+ self._flush_events()
145
+
146
+ if self._debounce_timer and self._debounce_timer.is_alive():
147
+ self._debounce_timer.join()
@@ -0,0 +1,58 @@
1
+ from uuid import uuid4
2
+ from typing import Generator
3
+
4
+ from vellum.workflows.context import get_execution_context
5
+ from vellum.workflows.events.stream import WorkflowEventGenerator
6
+ from vellum.workflows.events.workflow import (
7
+ WorkflowEvent,
8
+ WorkflowEventStream,
9
+ WorkflowExecutionInitiatedBody,
10
+ WorkflowExecutionInitiatedEvent,
11
+ WorkflowExecutionRejectedBody,
12
+ WorkflowExecutionRejectedEvent,
13
+ )
14
+ from vellum.workflows.exceptions import WorkflowInitializationException
15
+ from vellum.workflows.inputs import BaseInputs
16
+
17
+
18
+ def stream_initialization_exception(
19
+ exception: WorkflowInitializationException,
20
+ ) -> WorkflowEventStream:
21
+ """
22
+ Stream a workflow initiated event followed by a workflow rejected event for an initialization exception.
23
+
24
+ Args:
25
+ exception: The WorkflowInitializationException to stream events for
26
+
27
+ Returns:
28
+ WorkflowEventGenerator yielding initiated and rejected events
29
+ """
30
+
31
+ execution_context = get_execution_context()
32
+ span_id = uuid4()
33
+
34
+ def _generate_events() -> Generator[WorkflowEvent, None, None]:
35
+ initiated_event: WorkflowEvent = WorkflowExecutionInitiatedEvent(
36
+ trace_id=execution_context.trace_id,
37
+ span_id=span_id,
38
+ body=WorkflowExecutionInitiatedBody(
39
+ workflow_definition=exception.definition,
40
+ inputs=BaseInputs(),
41
+ initial_state=None,
42
+ ),
43
+ parent=execution_context.parent_context,
44
+ )
45
+ yield initiated_event
46
+
47
+ rejected_event = WorkflowExecutionRejectedEvent(
48
+ trace_id=execution_context.trace_id,
49
+ span_id=span_id,
50
+ body=WorkflowExecutionRejectedBody(
51
+ workflow_definition=exception.definition,
52
+ error=exception.error,
53
+ ),
54
+ parent=execution_context.parent_context,
55
+ )
56
+ yield rejected_event
57
+
58
+ return WorkflowEventGenerator(_generate_events(), span_id)
@@ -7,6 +7,7 @@ from deepdiff import DeepDiff
7
7
  from vellum.client.core.pydantic_utilities import UniversalBaseModel
8
8
  from vellum.workflows.constants import undefined
9
9
  from vellum.workflows.errors.types import WorkflowError, WorkflowErrorCode
10
+ from vellum.workflows.events.exception_handling import stream_initialization_exception
10
11
  from vellum.workflows.events.node import (
11
12
  NodeExecutionFulfilledBody,
12
13
  NodeExecutionFulfilledEvent,
@@ -26,6 +27,7 @@ from vellum.workflows.events.workflow import (
26
27
  WorkflowExecutionStreamingBody,
27
28
  WorkflowExecutionStreamingEvent,
28
29
  )
30
+ from vellum.workflows.exceptions import WorkflowInitializationException
29
31
  from vellum.workflows.inputs.base import BaseInputs
30
32
  from vellum.workflows.nodes.bases.base import BaseNode
31
33
  from vellum.workflows.outputs.base import BaseOutput
@@ -460,3 +462,28 @@ def test_parent_context__deserialize_from_json__invalid_parent_context():
460
462
  assert event.parent.type == "UNKNOWN"
461
463
  assert event.parent.span_id == UUID("123e4567-e89b-12d3-a456-426614174000")
462
464
  assert event.parent.parent is None
465
+
466
+
467
+ def test_workflow_event_generator_stream_initialization_exception():
468
+ """
469
+ Tests that stream_initialization_exception yields both initiated and rejected events with proper correlation.
470
+ """
471
+ exception = WorkflowInitializationException("Test initialization error", workflow_definition=MockWorkflow)
472
+
473
+ events = list(stream_initialization_exception(exception))
474
+
475
+ assert len(events) == 2
476
+
477
+ initiated_event = events[0]
478
+ assert initiated_event.name == "workflow.execution.initiated"
479
+ assert initiated_event.body.inputs is not None
480
+ assert initiated_event.body.initial_state is None
481
+ assert initiated_event.body.workflow_definition == MockWorkflow
482
+
483
+ rejected_event = events[1]
484
+ assert rejected_event.name == "workflow.execution.rejected"
485
+ assert rejected_event.body.error.message == "Test initialization error"
486
+ assert rejected_event.body.workflow_definition == MockWorkflow
487
+
488
+ assert initiated_event.trace_id == rejected_event.trace_id
489
+ assert initiated_event.span_id == rejected_event.span_id
@@ -1,7 +1,10 @@
1
- from typing import Any, Dict, Optional
1
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Type
2
2
 
3
3
  from vellum.workflows.errors import WorkflowError, WorkflowErrorCode
4
4
 
5
+ if TYPE_CHECKING:
6
+ from vellum.workflows.workflows.base import BaseWorkflow
7
+
5
8
 
6
9
  class NodeException(Exception):
7
10
  def __init__(
@@ -29,9 +32,15 @@ class NodeException(Exception):
29
32
 
30
33
 
31
34
  class WorkflowInitializationException(Exception):
32
- def __init__(self, message: str, code: WorkflowErrorCode = WorkflowErrorCode.INVALID_INPUTS):
35
+ def __init__(
36
+ self,
37
+ message: str,
38
+ workflow_definition: Type["BaseWorkflow"],
39
+ code: WorkflowErrorCode = WorkflowErrorCode.INVALID_INPUTS,
40
+ ):
33
41
  self.message = message
34
42
  self.code = code
43
+ self.definition = workflow_definition
35
44
  super().__init__(message)
36
45
 
37
46
  @property
@@ -40,7 +49,3 @@ class WorkflowInitializationException(Exception):
40
49
  message=self.message,
41
50
  code=self.code,
42
51
  )
43
-
44
- @staticmethod
45
- def of(workflow_error: WorkflowError) -> "WorkflowInitializationException":
46
- return WorkflowInitializationException(message=workflow_error.message, code=workflow_error.code)
@@ -95,6 +95,7 @@ class BaseInputs(metaclass=_BaseInputsMeta):
95
95
  raise WorkflowInitializationException(
96
96
  message=f"Required input variables {name} should have defined value",
97
97
  code=WorkflowErrorCode.INVALID_INPUTS,
98
+ workflow_definition=self.__class__.__parent_class__,
98
99
  )
99
100
 
100
101
  # If value provided in kwargs, set it on the instance
@@ -1,6 +1,6 @@
1
1
  from typing import Any, Dict
2
2
 
3
- from pydantic import field_serializer
3
+ from pydantic import Field, field_serializer
4
4
 
5
5
  from vellum.client.core.pydantic_utilities import UniversalBaseModel
6
6
  from vellum.workflows.inputs.base import BaseInputs
@@ -16,7 +16,7 @@ class DatasetRow(UniversalBaseModel):
16
16
  """
17
17
 
18
18
  label: str
19
- inputs: BaseInputs
19
+ inputs: BaseInputs = Field(default_factory=BaseInputs)
20
20
 
21
21
  @field_serializer("inputs")
22
22
  def serialize_inputs(self, inputs: BaseInputs) -> Dict[str, Any]:
@@ -1,4 +1,4 @@
1
- from abc import ABC, ABCMeta
1
+ from abc import ABC, ABCMeta, abstractmethod
2
2
  from dataclasses import field
3
3
  from functools import cached_property, reduce
4
4
  import inspect
@@ -215,6 +215,17 @@ class BaseNodeMeta(ABCMeta):
215
215
  yield attr_value
216
216
  yielded_attr_names.add(attr_name)
217
217
 
218
+ @abstractmethod
219
+ def __validate__(cls) -> None:
220
+ """
221
+ Validates the node.
222
+ Subclasses can override this method to implement their specific validation logic.
223
+ Called during serialization or explicit validation.
224
+
225
+ Default implementation performs no validation.
226
+ """
227
+ pass
228
+
218
229
 
219
230
  class _BaseNodeTriggerMeta(type):
220
231
  def __eq__(self, other: Any) -> bool:
@@ -112,4 +112,10 @@ class BasePromptNode(BaseNode[StateType], Generic[StateType]):
112
112
  if not target_node_output:
113
113
  return False
114
114
 
115
+ if not isinstance(target_node_output.instance, OutputReference):
116
+ return False
117
+
118
+ if target_node_output.instance.name != "text":
119
+ return False
120
+
115
121
  return True
@@ -45,9 +45,15 @@ from vellum.workflows.nodes.displayable.bases.base_prompt_node import BasePrompt
45
45
  from vellum.workflows.nodes.displayable.bases.utils import process_additional_prompt_outputs
46
46
  from vellum.workflows.outputs import BaseOutput
47
47
  from vellum.workflows.types import MergeBehavior
48
- from vellum.workflows.types.definition import DeploymentDefinition, MCPServer
48
+ from vellum.workflows.types.definition import (
49
+ ComposioToolDefinition,
50
+ DeploymentDefinition,
51
+ MCPServer,
52
+ VellumIntegrationToolDefinition,
53
+ )
49
54
  from vellum.workflows.types.generics import StateType, is_workflow_class
50
55
  from vellum.workflows.utils.functions import (
56
+ compile_composio_tool_definition,
51
57
  compile_function_definition,
52
58
  compile_inline_workflow_function_definition,
53
59
  compile_mcp_tool_definition,
@@ -134,7 +140,7 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
134
140
  elif isinstance(function, DeploymentDefinition):
135
141
  normalized_functions.append(
136
142
  compile_workflow_deployment_function_definition(
137
- function.model_dump(),
143
+ function,
138
144
  vellum_client=self._context.vellum_client,
139
145
  )
140
146
  )
@@ -142,6 +148,14 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
142
148
  normalized_functions.append(compile_inline_workflow_function_definition(function))
143
149
  elif callable(function):
144
150
  normalized_functions.append(compile_function_definition(function))
151
+ elif isinstance(function, ComposioToolDefinition):
152
+ normalized_functions.append(compile_composio_tool_definition(function))
153
+ elif isinstance(function, VellumIntegrationToolDefinition):
154
+ # TODO: Implement compile_vellum_integration_tool_definition
155
+ raise NotImplementedError(
156
+ "VellumIntegrationToolDefinition support coming soon. "
157
+ "This will be implemented when compile_vellum_integration_tool_definition is created."
158
+ )
145
159
  elif isinstance(function, MCPServer):
146
160
  tool_definitions = compile_mcp_tool_definition(function)
147
161
  for tool_def in tool_definitions:
@@ -5,6 +5,7 @@ from vellum.workflows.nodes.bases import BaseNode
5
5
  from vellum.workflows.nodes.bases.base import BaseNodeMeta
6
6
  from vellum.workflows.nodes.utils import cast_to_output_type
7
7
  from vellum.workflows.ports import NodePorts
8
+ from vellum.workflows.references.output import OutputReference
8
9
  from vellum.workflows.types import MergeBehavior
9
10
  from vellum.workflows.types.generics import StateType
10
11
  from vellum.workflows.types.utils import get_original_base
@@ -27,6 +28,7 @@ class _FinalOutputNodeMeta(BaseNodeMeta):
27
28
  **annotations,
28
29
  "value": parent.get_output_type(),
29
30
  }
31
+
30
32
  return parent
31
33
 
32
34
  def get_output_type(cls) -> Type:
@@ -38,6 +40,63 @@ class _FinalOutputNodeMeta(BaseNodeMeta):
38
40
  else:
39
41
  return all_args[1]
40
42
 
43
+ def __validate__(cls) -> None:
44
+ cls._validate_output_type_consistency(cls)
45
+
46
+ @classmethod
47
+ def _validate_output_type_consistency(mcs, cls: Type) -> None:
48
+ """
49
+ Validates that the declared output type of FinalOutputNode matches
50
+ the type of the descriptor assigned to the 'value' attribute in its Outputs class.
51
+
52
+ Raises ValueError if there's a type mismatch.
53
+ """
54
+ if not hasattr(cls, "Outputs"):
55
+ return
56
+
57
+ outputs_class = cls.Outputs
58
+ if not hasattr(outputs_class, "value"):
59
+ return
60
+
61
+ declared_output_type = cls.get_output_type()
62
+ value_descriptor = None
63
+
64
+ if "value" in outputs_class.__dict__:
65
+ value_descriptor = outputs_class.__dict__["value"]
66
+ else:
67
+ value_descriptor = getattr(outputs_class, "value")
68
+
69
+ if isinstance(value_descriptor, OutputReference):
70
+ descriptor_types = value_descriptor.types
71
+
72
+ type_mismatch = True
73
+ for descriptor_type in descriptor_types:
74
+ if descriptor_type == declared_output_type:
75
+ type_mismatch = False
76
+ break
77
+ try:
78
+ if issubclass(descriptor_type, declared_output_type) or issubclass(
79
+ declared_output_type, descriptor_type
80
+ ):
81
+ type_mismatch = False
82
+ break
83
+ except TypeError:
84
+ # Handle cases where types aren't classes (e.g., Union)
85
+ if str(descriptor_type) == str(declared_output_type):
86
+ type_mismatch = False
87
+ break
88
+
89
+ if type_mismatch:
90
+ declared_type_name = getattr(declared_output_type, "__name__", str(declared_output_type))
91
+ descriptor_type_names = [getattr(t, "__name__", str(t)) for t in descriptor_types]
92
+
93
+ raise ValueError(
94
+ f"Output type mismatch in {cls.__name__}: "
95
+ f"FinalOutputNode is declared with output type '{declared_type_name}' "
96
+ f"but the 'value' descriptor has type(s) {descriptor_type_names}. "
97
+ f"The output descriptor type must match the declared FinalOutputNode output type."
98
+ )
99
+
41
100
 
42
101
  class FinalOutputNode(BaseNode[StateType], Generic[StateType, _OutputType], metaclass=_FinalOutputNodeMeta):
43
102
  """
@@ -2,10 +2,11 @@ import pytest
2
2
 
3
3
  from vellum.workflows.exceptions import NodeException
4
4
  from vellum.workflows.nodes.displayable.final_output_node import FinalOutputNode
5
+ from vellum.workflows.nodes.displayable.inline_prompt_node import InlinePromptNode
5
6
  from vellum.workflows.state.base import BaseState
6
7
 
7
8
 
8
- def test_final_output_node__mismatched_output_type():
9
+ def test_final_output_node__mismatched_output_type_should_raise_exception_when_ran():
9
10
  # GIVEN a FinalOutputNode with a mismatched output type
10
11
  class StringOutputNode(FinalOutputNode[BaseState, str]):
11
12
  class Outputs(FinalOutputNode.Outputs):
@@ -18,3 +19,41 @@ def test_final_output_node__mismatched_output_type():
18
19
 
19
20
  # THEN an error is raised
20
21
  assert str(exc_info.value) == "Expected an output of type 'str', but received 'dict'"
22
+
23
+
24
+ def test_final_output_node__mismatched_output_type_should_raise_exception():
25
+ # GIVEN a FinalOutputNode declared with list output type but has a string value type
26
+ class Output(FinalOutputNode[BaseState, list]):
27
+ """Output the extracted invoice line items as an array of objects."""
28
+
29
+ class Outputs(FinalOutputNode.Outputs):
30
+ value = InlinePromptNode.Outputs.text
31
+
32
+ # WHEN attempting to validate the node class
33
+ # THEN a ValueError should be raised during validation
34
+ with pytest.raises(ValueError) as exc_info:
35
+ Output.__validate__()
36
+
37
+ # AND the error message should indicate the type mismatch
38
+ assert (
39
+ str(exc_info.value)
40
+ == "Output type mismatch in Output: FinalOutputNode is declared with output type 'list' but "
41
+ "the 'value' descriptor has type(s) ['str']. The output descriptor type must match the "
42
+ "declared FinalOutputNode output type."
43
+ )
44
+
45
+
46
+ def test_final_output_node__matching_output_type_should_pass_validation():
47
+ # GIVEN a FinalOutputNode declared with correct matching types
48
+ class CorrectOutput(FinalOutputNode[BaseState, str]):
49
+ """Output with correct type matching."""
50
+
51
+ class Outputs(FinalOutputNode.Outputs):
52
+ value = InlinePromptNode.Outputs.text
53
+
54
+ # WHEN attempting to validate the node class
55
+ # THEN validation should pass without raising an exception
56
+ try:
57
+ CorrectOutput.__validate__()
58
+ except ValueError:
59
+ pytest.fail("Validation should not raise an exception for correct type matching")
@@ -2,6 +2,7 @@ from typing import Any, ClassVar, Dict, Generic, Iterator, List, Optional, Set,
2
2
 
3
3
  from vellum import ChatMessage, PromptBlock
4
4
  from vellum.client.types.prompt_parameters import PromptParameters
5
+ from vellum.client.types.prompt_settings import PromptSettings
5
6
  from vellum.prompts.constants import DEFAULT_PROMPT_PARAMETERS
6
7
  from vellum.workflows.context import execution_context, get_parent_context
7
8
  from vellum.workflows.errors.types import WorkflowErrorCode
@@ -47,6 +48,7 @@ class ToolCallingNode(BaseNode[StateType], Generic[StateType]):
47
48
  prompt_inputs: ClassVar[Optional[EntityInputsInterface]] = None
48
49
  parameters: PromptParameters = DEFAULT_PROMPT_PARAMETERS
49
50
  max_prompt_iterations: ClassVar[Optional[int]] = 5
51
+ settings: ClassVar[Optional[Union[PromptSettings, Dict[str, Any]]]] = None
50
52
 
51
53
  class Outputs(BaseOutputs):
52
54
  """
@@ -150,6 +152,7 @@ class ToolCallingNode(BaseNode[StateType], Generic[StateType]):
150
152
  max_prompt_iterations=self.max_prompt_iterations,
151
153
  process_parameters_method=process_parameters_method,
152
154
  process_blocks_method=process_blocks_method,
155
+ settings=self.settings,
153
156
  )
154
157
 
155
158
  # Create the router node (handles routing logic only)
@@ -5,6 +5,7 @@ from vellum.client.types.chat_message_prompt_block import ChatMessagePromptBlock
5
5
  from vellum.client.types.fulfilled_execute_prompt_event import FulfilledExecutePromptEvent
6
6
  from vellum.client.types.initiated_execute_prompt_event import InitiatedExecutePromptEvent
7
7
  from vellum.client.types.plain_text_prompt_block import PlainTextPromptBlock
8
+ from vellum.client.types.prompt_settings import PromptSettings
8
9
  from vellum.client.types.rich_text_prompt_block import RichTextPromptBlock
9
10
  from vellum.client.types.string_vellum_value import StringVellumValue
10
11
  from vellum.client.types.variable_prompt_block import VariablePromptBlock
@@ -250,3 +251,66 @@ def test_get_mcp_tool_name_snake_case():
250
251
 
251
252
  result = get_mcp_tool_name(mcp_tool)
252
253
  assert result == "github_server__create_repository"
254
+
255
+
256
+ def test_create_tool_prompt_node_settings_dict_stream_disabled(vellum_adhoc_prompt_client):
257
+ # GIVEN settings provided as dict with stream disabled
258
+ tool_prompt_node = create_tool_prompt_node(
259
+ ml_model="gpt-4o-mini",
260
+ blocks=[],
261
+ functions=[],
262
+ prompt_inputs=None,
263
+ parameters=DEFAULT_PROMPT_PARAMETERS,
264
+ max_prompt_iterations=1,
265
+ settings={"stream_enabled": False},
266
+ )
267
+
268
+ # AND the API mocks
269
+ def generate_non_stream_response(*args, **kwargs):
270
+ return FulfilledExecutePromptEvent(execution_id=str(uuid4()), outputs=[StringVellumValue(value="ok")])
271
+
272
+ vellum_adhoc_prompt_client.adhoc_execute_prompt.side_effect = generate_non_stream_response
273
+
274
+ # WHEN we run the node
275
+ node_instance = tool_prompt_node()
276
+ list(node_instance.run())
277
+
278
+ # THEN the node should have called the API correctly
279
+ assert node_instance.settings is not None
280
+ assert node_instance.settings.stream_enabled is False
281
+ assert vellum_adhoc_prompt_client.adhoc_execute_prompt.call_count == 1
282
+ assert vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.call_count == 0
283
+
284
+
285
+ def test_create_tool_prompt_node_settings_model_stream_enabled(vellum_adhoc_prompt_client):
286
+ # GIVEN settings provided as PromptSettings with stream enabled
287
+ tool_prompt_node = create_tool_prompt_node(
288
+ ml_model="gpt-4o-mini",
289
+ blocks=[],
290
+ functions=[],
291
+ prompt_inputs=None,
292
+ parameters=DEFAULT_PROMPT_PARAMETERS,
293
+ max_prompt_iterations=1,
294
+ settings=PromptSettings(stream_enabled=True),
295
+ )
296
+
297
+ # AND the API mocks
298
+ def generate_stream_events(*args, **kwargs):
299
+ execution_id = str(uuid4())
300
+ events = [
301
+ InitiatedExecutePromptEvent(execution_id=execution_id),
302
+ FulfilledExecutePromptEvent(execution_id=execution_id, outputs=[StringVellumValue(value="ok")]),
303
+ ]
304
+ yield from events
305
+
306
+ vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.side_effect = generate_stream_events
307
+
308
+ # WHEN we run the node
309
+ node_instance = tool_prompt_node()
310
+ list(node_instance.run())
311
+
312
+ # THEN the node should have called the API correctly
313
+ assert node_instance.settings is not None
314
+ assert node_instance.settings.stream_enabled is True
315
+ assert vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.call_count == 1
316
+ assert vellum_adhoc_prompt_client.adhoc_execute_prompt.call_count == 0
@@ -9,9 +9,9 @@ from vellum.client.types.array_chat_message_content import ArrayChatMessageConte
9
9
  from vellum.client.types.array_chat_message_content_item import ArrayChatMessageContentItem
10
10
  from vellum.client.types.function_call_chat_message_content import FunctionCallChatMessageContent
11
11
  from vellum.client.types.function_call_chat_message_content_value import FunctionCallChatMessageContentValue
12
- from vellum.client.types.function_definition import FunctionDefinition
13
12
  from vellum.client.types.prompt_output import PromptOutput
14
13
  from vellum.client.types.prompt_parameters import PromptParameters
14
+ from vellum.client.types.prompt_settings import PromptSettings
15
15
  from vellum.client.types.string_chat_message_content import StringChatMessageContent
16
16
  from vellum.client.types.variable_prompt_block import VariablePromptBlock
17
17
  from vellum.workflows.descriptors.base import BaseDescriptor
@@ -31,7 +31,13 @@ from vellum.workflows.ports.port import Port
31
31
  from vellum.workflows.state import BaseState
32
32
  from vellum.workflows.state.encoder import DefaultStateEncoder
33
33
  from vellum.workflows.types.core import EntityInputsInterface, MergeBehavior, Tool, ToolBase
34
- from vellum.workflows.types.definition import ComposioToolDefinition, DeploymentDefinition, MCPServer, MCPToolDefinition
34
+ from vellum.workflows.types.definition import (
35
+ ComposioToolDefinition,
36
+ DeploymentDefinition,
37
+ MCPServer,
38
+ MCPToolDefinition,
39
+ VellumIntegrationToolDefinition,
40
+ )
35
41
  from vellum.workflows.types.generics import is_workflow_class
36
42
  from vellum.workflows.utils.functions import compile_mcp_tool_definition, get_mcp_tool_name
37
43
 
@@ -274,36 +280,6 @@ class ElseNode(BaseNode[ToolCallingState]):
274
280
  return self.Outputs()
275
281
 
276
282
 
277
- def _hydrate_composio_tool_definition(tool_def: ComposioToolDefinition) -> FunctionDefinition:
278
- """Hydrate a ComposioToolDefinition with detailed information from the Composio API.
279
-
280
- Args:
281
- tool_def: The basic ComposioToolDefinition to enhance
282
-
283
- Returns:
284
- FunctionDefinition with detailed parameters and description
285
- """
286
- try:
287
- composio_service = ComposioService()
288
- tool_details = composio_service.get_tool_by_slug(tool_def.action)
289
-
290
- # Create a FunctionDefinition directly with proper field extraction
291
- return FunctionDefinition(
292
- name=tool_def.name,
293
- description=tool_details.get("description", tool_def.description),
294
- parameters=tool_details.get("input_parameters", {}),
295
- )
296
-
297
- except Exception as e:
298
- # If hydration fails (including no API key), log and return basic function definition
299
- logger.warning(f"Failed to enhance Composio tool '{tool_def.action}': {e}")
300
- return FunctionDefinition(
301
- name=tool_def.name,
302
- description=tool_def.description,
303
- parameters={},
304
- )
305
-
306
-
307
283
  def create_tool_prompt_node(
308
284
  ml_model: str,
309
285
  blocks: List[Union[PromptBlock, Dict[str, Any]]],
@@ -313,17 +289,10 @@ def create_tool_prompt_node(
313
289
  max_prompt_iterations: Optional[int] = None,
314
290
  process_parameters_method: Optional[Callable] = None,
315
291
  process_blocks_method: Optional[Callable] = None,
292
+ settings: Optional[Union[PromptSettings, Dict[str, Any]]] = None,
316
293
  ) -> Type[ToolPromptNode]:
317
294
  if functions and len(functions) > 0:
318
- prompt_functions: List[Union[Tool, FunctionDefinition]] = []
319
-
320
- for function in functions:
321
- if isinstance(function, ComposioToolDefinition):
322
- # Get Composio tool details and hydrate the function definition
323
- enhanced_function = _hydrate_composio_tool_definition(function)
324
- prompt_functions.append(enhanced_function)
325
- else:
326
- prompt_functions.append(function)
295
+ prompt_functions: List[Tool] = functions
327
296
  else:
328
297
  prompt_functions = []
329
298
 
@@ -359,6 +328,13 @@ def create_tool_prompt_node(
359
328
  ),
360
329
  }
361
330
 
331
+ # Normalize settings to PromptSettings if provided as a dict
332
+ normalized_settings: Optional[PromptSettings]
333
+ if isinstance(settings, dict):
334
+ normalized_settings = PromptSettings.model_validate(settings)
335
+ else:
336
+ normalized_settings = settings
337
+
362
338
  node = cast(
363
339
  Type[ToolPromptNode],
364
340
  type(
@@ -371,6 +347,7 @@ def create_tool_prompt_node(
371
347
  "prompt_inputs": node_prompt_inputs,
372
348
  "parameters": parameters,
373
349
  "max_prompt_iterations": max_prompt_iterations,
350
+ "settings": normalized_settings,
374
351
  **({"process_parameters": process_parameters_method} if process_parameters_method is not None else {}),
375
352
  **({"process_blocks": process_blocks_method} if process_blocks_method is not None else {}),
376
353
  "__module__": __name__,
@@ -409,6 +386,10 @@ def create_router_node(
409
386
  function_name = get_function_name(function)
410
387
  port = create_port_condition(function_name)
411
388
  setattr(Ports, function_name, port)
389
+ elif isinstance(function, VellumIntegrationToolDefinition):
390
+ function_name = get_function_name(function)
391
+ port = create_port_condition(function_name)
392
+ setattr(Ports, function_name, port)
412
393
  elif isinstance(function, MCPServer):
413
394
  tool_functions: List[MCPToolDefinition] = compile_mcp_tool_definition(function)
414
395
  for tool_function in tool_functions:
@@ -483,6 +464,12 @@ def create_function_node(
483
464
  },
484
465
  )
485
466
  return node
467
+ elif isinstance(function, VellumIntegrationToolDefinition):
468
+ # TODO: Implement VellumIntegrationNode
469
+ raise NotImplementedError(
470
+ "VellumIntegrationToolDefinition support coming soon. "
471
+ "This will be implemented when the VellumIntegrationService is created."
472
+ )
486
473
  elif is_workflow_class(function):
487
474
  function.is_dynamic = True
488
475
  node = type(
@@ -572,5 +559,7 @@ def get_function_name(function: ToolBase) -> str:
572
559
  elif isinstance(function, ComposioToolDefinition):
573
560
  # model post init sets the name to the action if it's not set
574
561
  return function.name # type: ignore[return-value]
562
+ elif isinstance(function, VellumIntegrationToolDefinition):
563
+ return function.name
575
564
  else:
576
565
  return snake_case(function.__name__)