vellum-ai 1.0.10__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. vellum/__init__.py +2 -2
  2. vellum/client/__init__.py +0 -4
  3. vellum/client/core/client_wrapper.py +2 -2
  4. vellum/client/reference.md +2 -3
  5. vellum/client/resources/__init__.py +0 -2
  6. vellum/client/resources/workflow_deployments/client.py +119 -0
  7. vellum/client/types/__init__.py +2 -0
  8. vellum/client/types/api_request_parent_context.py +1 -0
  9. vellum/client/types/external_parent_context.py +36 -0
  10. vellum/client/types/node_execution_fulfilled_event.py +1 -0
  11. vellum/client/types/node_execution_initiated_event.py +1 -0
  12. vellum/client/types/node_execution_paused_event.py +1 -0
  13. vellum/client/types/node_execution_rejected_event.py +1 -0
  14. vellum/client/types/node_execution_resumed_event.py +1 -0
  15. vellum/client/types/node_execution_span.py +1 -0
  16. vellum/client/types/node_execution_span_attributes.py +1 -0
  17. vellum/client/types/node_execution_streaming_event.py +1 -0
  18. vellum/client/types/node_parent_context.py +1 -0
  19. vellum/client/types/parent_context.py +2 -0
  20. vellum/client/types/prompt_deployment_parent_context.py +1 -0
  21. vellum/client/types/slim_workflow_execution_read.py +1 -0
  22. vellum/client/types/span_link.py +1 -0
  23. vellum/client/types/workflow_deployment_event_executions_response.py +1 -0
  24. vellum/client/types/workflow_deployment_parent_context.py +1 -0
  25. vellum/client/types/workflow_event_execution_read.py +1 -0
  26. vellum/client/types/workflow_execution_detail.py +1 -0
  27. vellum/client/types/workflow_execution_fulfilled_event.py +1 -0
  28. vellum/client/types/workflow_execution_initiated_event.py +1 -0
  29. vellum/client/types/workflow_execution_paused_event.py +1 -0
  30. vellum/client/types/workflow_execution_rejected_event.py +1 -0
  31. vellum/client/types/workflow_execution_resumed_event.py +1 -0
  32. vellum/client/types/workflow_execution_snapshotted_event.py +1 -0
  33. vellum/client/types/workflow_execution_span.py +1 -0
  34. vellum/client/types/workflow_execution_span_attributes.py +1 -0
  35. vellum/client/types/workflow_execution_streaming_event.py +1 -0
  36. vellum/client/types/workflow_parent_context.py +1 -0
  37. vellum/client/types/workflow_sandbox_parent_context.py +1 -0
  38. vellum/{resources/release_reviews/__init__.py → types/external_parent_context.py} +1 -1
  39. vellum/workflows/descriptors/base.py +31 -1
  40. vellum/workflows/descriptors/utils.py +19 -1
  41. vellum/workflows/emitters/vellum_emitter.py +3 -2
  42. vellum/workflows/events/types.py +6 -0
  43. vellum/workflows/expressions/accessor.py +23 -15
  44. vellum/workflows/expressions/add.py +41 -0
  45. vellum/workflows/expressions/length.py +35 -0
  46. vellum/workflows/expressions/minus.py +41 -0
  47. vellum/workflows/expressions/tests/test_add.py +72 -0
  48. vellum/workflows/expressions/tests/test_length.py +38 -0
  49. vellum/workflows/expressions/tests/test_minus.py +72 -0
  50. vellum/workflows/integrations/composio_service.py +4 -0
  51. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +1 -1
  52. vellum/workflows/nodes/displayable/inline_prompt_node/node.py +2 -2
  53. vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +5 -15
  54. vellum/workflows/nodes/displayable/tool_calling_node/node.py +12 -1
  55. vellum/workflows/nodes/displayable/tool_calling_node/state.py +2 -0
  56. vellum/workflows/nodes/displayable/tool_calling_node/tests/test_composio_service.py +49 -0
  57. vellum/workflows/nodes/displayable/tool_calling_node/tests/test_node.py +3 -8
  58. vellum/workflows/nodes/displayable/tool_calling_node/utils.py +167 -50
  59. vellum/workflows/state/context.py +13 -2
  60. vellum/workflows/types/definition.py +3 -8
  61. vellum/workflows/types/tests/test_definition.py +3 -4
  62. vellum/workflows/utils/functions.py +1 -1
  63. vellum/workflows/utils/tests/test_functions.py +3 -3
  64. {vellum_ai-1.0.10.dist-info → vellum_ai-1.1.0.dist-info}/METADATA +1 -1
  65. {vellum_ai-1.0.10.dist-info → vellum_ai-1.1.0.dist-info}/RECORD +73 -68
  66. vellum_ee/workflows/display/nodes/vellum/tests/test_tool_calling_node.py +93 -0
  67. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_composio_serialization.py +0 -4
  68. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_mcp_serialization.py +98 -0
  69. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_workflow_deployment_serialization.py +1 -1
  70. vellum_ee/workflows/display/utils/expressions.py +13 -1
  71. vellum/client/resources/release_reviews/__init__.py +0 -2
  72. vellum/client/resources/release_reviews/client.py +0 -139
  73. vellum/resources/release_reviews/client.py +0 -3
  74. {vellum_ai-1.0.10.dist-info → vellum_ai-1.1.0.dist-info}/LICENSE +0 -0
  75. {vellum_ai-1.0.10.dist-info → vellum_ai-1.1.0.dist-info}/WHEEL +0 -0
  76. {vellum_ai-1.0.10.dist-info → vellum_ai-1.1.0.dist-info}/entry_points.txt +0 -0
@@ -1,7 +1,8 @@
1
1
  from collections.abc import Mapping
2
2
  import dataclasses
3
3
  import inspect
4
- from typing import Any, Dict, Optional, Sequence, Set, TypeVar, Union, cast, overload
4
+ from typing import Any, Dict, Optional, Sequence, Set, Type, TypeVar, Union, cast, overload
5
+ from typing_extensions import TypeGuard
5
6
 
6
7
  from pydantic import BaseModel
7
8
 
@@ -115,3 +116,20 @@ def is_unresolved(value: Any) -> bool:
115
116
  return any(is_unresolved(item) for item in value)
116
117
 
117
118
  return False
119
+
120
+
121
+ _ResolvedType = TypeVar("_ResolvedType")
122
+
123
+
124
+ def is_resolved_instance(value: Any, type_: Type[_ResolvedType]) -> TypeGuard[_ResolvedType]:
125
+ """
126
+ Checks if a value is an instance of a type or a descriptor that resolves to that type.
127
+ """
128
+
129
+ if isinstance(value, type_):
130
+ return True
131
+
132
+ if isinstance(value, BaseDescriptor) and value.types:
133
+ return issubclass(value.types[0], type_)
134
+
135
+ return False
@@ -41,7 +41,7 @@ class VellumEmitter(BaseWorkflowEmitter):
41
41
  super().__init__()
42
42
  self._timeout = timeout
43
43
  self._max_retries = max_retries
44
- self._events_endpoint = "events" # TODO: make this configurable with the correct url
44
+ self._events_endpoint = "v1/events" # TODO: make this configurable with the correct url
45
45
 
46
46
  def emit_event(self, event: WorkflowEvent) -> None:
47
47
  """
@@ -90,7 +90,8 @@ class VellumEmitter(BaseWorkflowEmitter):
90
90
  base_url = client._client_wrapper.get_environment().default
91
91
  response = client._client_wrapper.httpx_client.request(
92
92
  method="POST",
93
- path=f"{base_url}/{self._events_endpoint}", # TODO: will be replaced with the correct url
93
+ base_url=base_url,
94
+ path=self._events_endpoint, # TODO: will be replaced with the correct url
94
95
  json=event_data,
95
96
  headers=client._client_wrapper.get_headers(),
96
97
  request_options={"timeout_in_seconds": self._timeout},
@@ -80,6 +80,11 @@ class UnknownParentContext(BaseParentContext):
80
80
  type: Literal["UNKNOWN"] = "UNKNOWN"
81
81
 
82
82
 
83
+ # Setting external parent context for external workflows
84
+ class ExternalParentContext(BaseParentContext):
85
+ type: Literal["EXTERNAL"] = "EXTERNAL"
86
+
87
+
83
88
  def _cast_parent_context_discriminator(v: Any) -> Any:
84
89
  if v in PARENT_CONTEXT_TYPES:
85
90
  return v
@@ -138,6 +143,7 @@ ParentContext = Annotated[
138
143
  PromptDeploymentParentContext,
139
144
  WorkflowSandboxParentContext,
140
145
  APIRequestParentContext,
146
+ ExternalParentContext,
141
147
  UnknownParentContext,
142
148
  ],
143
149
  ParentContextDiscriminator(),
@@ -7,10 +7,11 @@ from pydantic_core import core_schema
7
7
 
8
8
  from vellum.workflows.descriptors.base import BaseDescriptor
9
9
  from vellum.workflows.descriptors.exceptions import InvalidExpressionException
10
- from vellum.workflows.descriptors.utils import resolve_value
10
+ from vellum.workflows.descriptors.utils import is_resolved_instance, resolve_value
11
11
  from vellum.workflows.state.base import BaseState
12
12
 
13
13
  LHS = TypeVar("LHS")
14
+ AccessorField = Union[str, int, BaseDescriptor[str], BaseDescriptor[int]]
14
15
 
15
16
 
16
17
  class AccessorExpression(BaseDescriptor[Any]):
@@ -18,7 +19,7 @@ class AccessorExpression(BaseDescriptor[Any]):
18
19
  self,
19
20
  *,
20
21
  base: BaseDescriptor[LHS],
21
- field: Union[str, int],
22
+ field: AccessorField,
22
23
  ) -> None:
23
24
  super().__init__(
24
25
  name=f"{base.name}.{field}",
@@ -28,7 +29,7 @@ class AccessorExpression(BaseDescriptor[Any]):
28
29
  self._base = base
29
30
  self._field = field
30
31
 
31
- def _infer_accessor_types(self, base: BaseDescriptor[LHS], field: Union[str, int]) -> tuple[Type, ...]:
32
+ def _infer_accessor_types(self, base: BaseDescriptor[LHS], field: AccessorField) -> tuple[Type, ...]:
32
33
  """
33
34
  Infer the types for this accessor expression based on the base descriptor's types
34
35
  and the field being accessed.
@@ -42,7 +43,7 @@ class AccessorExpression(BaseDescriptor[Any]):
42
43
  origin = get_origin(base_type)
43
44
  args = get_args(base_type)
44
45
 
45
- if isinstance(field, int) and origin in (list, tuple) and args:
46
+ if is_resolved_instance(field, int) and origin in (list, tuple) and args:
46
47
  if origin is list:
47
48
  inferred_types.append(args[0])
48
49
  elif origin is tuple and len(args) == 2 and args[1] is ...:
@@ -52,44 +53,51 @@ class AccessorExpression(BaseDescriptor[Any]):
52
53
  inferred_types.append(args[field])
53
54
  else:
54
55
  inferred_types.append(args[field])
55
- elif isinstance(field, str) and origin in (dict,) and len(args) >= 2:
56
+ elif is_resolved_instance(field, str) and origin in (dict,) and len(args) >= 2:
56
57
  inferred_types.append(args[1]) # Value type from Dict[K, V]
57
58
 
58
59
  return tuple(set(inferred_types)) if inferred_types else ()
59
60
 
60
61
  def resolve(self, state: "BaseState") -> Any:
61
62
  base = resolve_value(self._base, state)
63
+ accessor_field = resolve_value(self._field, state)
62
64
 
63
65
  if dataclasses.is_dataclass(base):
64
- if isinstance(self._field, int):
66
+ if isinstance(accessor_field, int):
65
67
  raise InvalidExpressionException("Cannot access field by index on a dataclass")
66
68
 
67
69
  try:
68
- return getattr(base, self._field)
70
+ return getattr(base, accessor_field)
69
71
  except AttributeError:
70
- raise InvalidExpressionException(f"Field '{self._field}' not found on dataclass {type(base).__name__}")
72
+ raise InvalidExpressionException(
73
+ f"Field '{accessor_field}' not found on dataclass {type(base).__name__}"
74
+ )
71
75
 
72
76
  if isinstance(base, BaseModel):
73
- if isinstance(self._field, int):
77
+ if isinstance(accessor_field, int):
74
78
  raise InvalidExpressionException("Cannot access field by index on a BaseModel")
75
79
 
76
80
  try:
77
- return getattr(base, self._field)
81
+ return getattr(base, accessor_field)
78
82
  except AttributeError:
79
- raise InvalidExpressionException(f"Field '{self._field}' not found on BaseModel {type(base).__name__}")
83
+ raise InvalidExpressionException(
84
+ f"Field '{accessor_field}' not found on BaseModel {type(base).__name__}"
85
+ )
80
86
 
81
87
  if isinstance(base, Mapping):
82
88
  try:
83
- return base[self._field]
89
+ return base[accessor_field]
84
90
  except KeyError:
85
- raise InvalidExpressionException(f"Key '{self._field}' not found in mapping")
91
+ raise InvalidExpressionException(f"Key '{accessor_field}' not found in mapping")
86
92
 
87
93
  if isinstance(base, Sequence):
88
94
  try:
89
- index = int(self._field)
95
+ index = int(accessor_field)
90
96
  return base[index]
91
97
  except (IndexError, ValueError):
92
- if isinstance(self._field, int) or (isinstance(self._field, str) and self._field.lstrip("-").isdigit()):
98
+ if isinstance(accessor_field, int) or (
99
+ isinstance(accessor_field, str) and accessor_field.lstrip("-").isdigit()
100
+ ):
93
101
  raise InvalidExpressionException(
94
102
  f"Index {self._field} is out of bounds for sequence of length {len(base)}"
95
103
  )
@@ -0,0 +1,41 @@
1
+ from typing import Any, Generic, Protocol, TypeVar, Union, runtime_checkable
2
+ from typing_extensions import TypeGuard
3
+
4
+ from vellum.workflows.descriptors.base import BaseDescriptor
5
+ from vellum.workflows.descriptors.exceptions import InvalidExpressionException
6
+ from vellum.workflows.descriptors.utils import resolve_value
7
+ from vellum.workflows.state.base import BaseState
8
+
9
+
10
+ @runtime_checkable
11
+ class SupportsAdd(Protocol):
12
+ def __add__(self, other: Any) -> Any: ...
13
+
14
+
15
+ def has_add(obj: Any) -> TypeGuard[SupportsAdd]:
16
+ return hasattr(obj, "__add__")
17
+
18
+
19
+ LHS = TypeVar("LHS")
20
+ RHS = TypeVar("RHS")
21
+
22
+
23
+ class AddExpression(BaseDescriptor[Any], Generic[LHS, RHS]):
24
+ def __init__(
25
+ self,
26
+ *,
27
+ lhs: Union[BaseDescriptor[LHS], LHS],
28
+ rhs: Union[BaseDescriptor[RHS], RHS],
29
+ ) -> None:
30
+ super().__init__(name=f"{lhs} + {rhs}", types=(object,))
31
+ self._lhs = lhs
32
+ self._rhs = rhs
33
+
34
+ def resolve(self, state: "BaseState") -> Any:
35
+ lhs = resolve_value(self._lhs, state)
36
+ rhs = resolve_value(self._rhs, state)
37
+
38
+ if not has_add(lhs):
39
+ raise InvalidExpressionException(f"'{lhs.__class__.__name__}' must support the '+' operator")
40
+
41
+ return lhs + rhs
@@ -0,0 +1,35 @@
1
+ from typing import Generic, TypeVar, Union
2
+
3
+ from vellum.workflows.constants import undefined
4
+ from vellum.workflows.descriptors.base import BaseDescriptor
5
+ from vellum.workflows.descriptors.exceptions import InvalidExpressionException
6
+ from vellum.workflows.descriptors.utils import resolve_value
7
+ from vellum.workflows.state.base import BaseState
8
+
9
+ _T = TypeVar("_T")
10
+
11
+
12
+ class LengthExpression(BaseDescriptor[int], Generic[_T]):
13
+ def __init__(
14
+ self,
15
+ *,
16
+ expression: Union[BaseDescriptor[_T], _T],
17
+ ) -> None:
18
+ super().__init__(name=f"length({expression})", types=(int,))
19
+ self._expression = expression
20
+
21
+ def resolve(self, state: "BaseState") -> int:
22
+ expression = resolve_value(self._expression, state)
23
+
24
+ if expression is undefined:
25
+ raise InvalidExpressionException("Cannot get length of undefined value")
26
+
27
+ if not hasattr(expression, "__len__"):
28
+ raise InvalidExpressionException(
29
+ f"Expected an object that supports `len()`, got `{expression.__class__.__name__}`"
30
+ )
31
+
32
+ try:
33
+ return len(expression)
34
+ except TypeError as e:
35
+ raise InvalidExpressionException(f"Cannot get length of `{expression.__class__.__name__}`: {str(e)}")
@@ -0,0 +1,41 @@
1
+ from typing import Any, Generic, Protocol, TypeVar, Union, runtime_checkable
2
+ from typing_extensions import TypeGuard
3
+
4
+ from vellum.workflows.descriptors.base import BaseDescriptor
5
+ from vellum.workflows.descriptors.exceptions import InvalidExpressionException
6
+ from vellum.workflows.descriptors.utils import resolve_value
7
+ from vellum.workflows.state.base import BaseState
8
+
9
+
10
+ @runtime_checkable
11
+ class SupportsMinus(Protocol):
12
+ def __sub__(self, other: Any) -> Any: ...
13
+
14
+
15
+ def has_sub(obj: Any) -> TypeGuard[SupportsMinus]:
16
+ return hasattr(obj, "__sub__")
17
+
18
+
19
+ LHS = TypeVar("LHS")
20
+ RHS = TypeVar("RHS")
21
+
22
+
23
+ class MinusExpression(BaseDescriptor[Any], Generic[LHS, RHS]):
24
+ def __init__(
25
+ self,
26
+ *,
27
+ lhs: Union[BaseDescriptor[LHS], LHS],
28
+ rhs: Union[BaseDescriptor[RHS], RHS],
29
+ ) -> None:
30
+ super().__init__(name=f"{lhs} - {rhs}", types=(object,))
31
+ self._lhs = lhs
32
+ self._rhs = rhs
33
+
34
+ def resolve(self, state: "BaseState") -> Any:
35
+ lhs = resolve_value(self._lhs, state)
36
+ rhs = resolve_value(self._rhs, state)
37
+
38
+ if not has_sub(lhs):
39
+ raise InvalidExpressionException(f"'{lhs.__class__.__name__}' must support the '-' operator")
40
+
41
+ return lhs - rhs
@@ -0,0 +1,72 @@
1
+ import pytest
2
+
3
+ from vellum.workflows.descriptors.exceptions import InvalidExpressionException
4
+ from vellum.workflows.expressions.add import AddExpression
5
+ from vellum.workflows.state.base import BaseState
6
+
7
+
8
+ class TestState(BaseState):
9
+ number_value: int = 5
10
+ string_value: str = "hello"
11
+
12
+
13
+ def test_add_expression_numbers():
14
+ """
15
+ Tests that AddExpression correctly adds two numbers.
16
+ """
17
+
18
+ state = TestState()
19
+
20
+ expression = TestState.number_value.add(10)
21
+
22
+ result = expression.resolve(state)
23
+ assert result == 15
24
+
25
+
26
+ def test_add_expression_strings():
27
+ """
28
+ Tests that AddExpression correctly concatenates two strings.
29
+ """
30
+
31
+ state = TestState()
32
+
33
+ expression = TestState.string_value.add(" world")
34
+
35
+ result = expression.resolve(state)
36
+ assert result == "hello world"
37
+
38
+
39
+ def test_add_expression_unsupported_type():
40
+ """
41
+ Tests that AddExpression raises an exception for unsupported types.
42
+ """
43
+
44
+ class NoAddSupport:
45
+ pass
46
+
47
+ no_add_obj = NoAddSupport()
48
+ expression = AddExpression(lhs=no_add_obj, rhs=5)
49
+ state = TestState()
50
+
51
+ with pytest.raises(InvalidExpressionException, match="'NoAddSupport' must support the '\\+' operator"):
52
+ expression.resolve(state)
53
+
54
+
55
+ def test_add_expression_name():
56
+ """
57
+ Tests that AddExpression has the correct name.
58
+ """
59
+
60
+ expression = AddExpression(lhs=5, rhs=3)
61
+
62
+ assert expression.name == "5 + 3"
63
+
64
+
65
+ def test_add_expression_types():
66
+ """
67
+ Tests that AddExpression has the correct types.
68
+ """
69
+
70
+ expression = AddExpression(lhs=5, rhs=3)
71
+
72
+ assert expression.types == (object,)
@@ -0,0 +1,38 @@
1
+ import pytest
2
+
3
+ from vellum.workflows.constants import undefined
4
+ from vellum.workflows.descriptors.exceptions import InvalidExpressionException
5
+ from vellum.workflows.expressions.length import LengthExpression
6
+ from vellum.workflows.state.base import BaseState
7
+
8
+
9
+ class TestState(BaseState):
10
+ string_value: str = "hello world"
11
+
12
+
13
+ def test_length_expression_string():
14
+ """
15
+ Tests that LengthExpression correctly returns the length of a string.
16
+ """
17
+
18
+ state = TestState()
19
+
20
+ expression = TestState.string_value.length()
21
+ result = expression.resolve(state)
22
+
23
+ assert result == 11
24
+
25
+
26
+ def test_length_expression_undefined():
27
+ """
28
+ Tests that LengthExpression raises an exception for undefined values.
29
+ """
30
+
31
+ expression = LengthExpression(expression=undefined)
32
+ state = TestState()
33
+
34
+ # THEN we should get an InvalidExpressionException
35
+ with pytest.raises(InvalidExpressionException) as exc_info:
36
+ expression.resolve(state)
37
+
38
+ assert "Cannot get length of undefined value" in str(exc_info.value)
@@ -0,0 +1,72 @@
1
+ import pytest
2
+
3
+ from vellum.workflows.descriptors.exceptions import InvalidExpressionException
4
+ from vellum.workflows.expressions.minus import MinusExpression
5
+ from vellum.workflows.state.base import BaseState
6
+
7
+
8
+ class TestState(BaseState):
9
+ number_value: int = 10
10
+ float_value: float = 15.5
11
+
12
+
13
+ def test_minus_expression_numbers():
14
+ """
15
+ Tests that MinusExpression correctly subtracts two numbers.
16
+ """
17
+
18
+ state = TestState()
19
+
20
+ expression = TestState.number_value.minus(3)
21
+
22
+ result = expression.resolve(state)
23
+ assert result == 7
24
+
25
+
26
+ def test_minus_expression_floats():
27
+ """
28
+ Tests that MinusExpression correctly subtracts two floats.
29
+ """
30
+
31
+ state = TestState()
32
+
33
+ expression = TestState.float_value.minus(5.5)
34
+
35
+ result = expression.resolve(state)
36
+ assert result == 10.0
37
+
38
+
39
+ def test_minus_expression_unsupported_type():
40
+ """
41
+ Tests that MinusExpression raises an exception for unsupported types.
42
+ """
43
+
44
+ class NoSubSupport:
45
+ pass
46
+
47
+ no_sub_obj = NoSubSupport()
48
+ expression = MinusExpression(lhs=no_sub_obj, rhs=5)
49
+ state = TestState()
50
+
51
+ with pytest.raises(InvalidExpressionException, match="'NoSubSupport' must support the '-' operator"):
52
+ expression.resolve(state)
53
+
54
+
55
+ def test_minus_expression_name():
56
+ """
57
+ Tests that MinusExpression has the correct name.
58
+ """
59
+
60
+ expression = MinusExpression(lhs=10, rhs=3)
61
+
62
+ assert expression.name == "10 - 3"
63
+
64
+
65
+ def test_minus_expression_types():
66
+ """
67
+ Tests that MinusExpression has the correct types.
68
+ """
69
+
70
+ expression = MinusExpression(lhs=10, rhs=3)
71
+
72
+ assert expression.types == (object,)
@@ -151,4 +151,8 @@ class ComposioService:
151
151
  if user_id is not None:
152
152
  json_data["user_id"] = user_id
153
153
  response = self._make_request(endpoint, method="POST", json_data=json_data)
154
+
155
+ if not response.get("successful", True):
156
+ return response.get("error", "Tool execution failed")
157
+
154
158
  return response.get("data", response)
@@ -14,7 +14,7 @@ from vellum.workflows.types.core import EntityInputsInterface, MergeBehavior
14
14
  from vellum.workflows.types.generics import StateType
15
15
 
16
16
 
17
- class BasePromptNode(BaseNode, Generic[StateType]):
17
+ class BasePromptNode(BaseNode[StateType], Generic[StateType]):
18
18
  # Inputs that are passed to the Prompt
19
19
  prompt_inputs: ClassVar[Optional[EntityInputsInterface]] = None
20
20
 
@@ -1,5 +1,5 @@
1
1
  import json
2
- from typing import Any, Dict, Iterator, Type, Union
2
+ from typing import Any, Dict, Generic, Iterator, Type, Union
3
3
 
4
4
  from vellum.workflows.constants import undefined
5
5
  from vellum.workflows.errors import WorkflowErrorCode
@@ -10,7 +10,7 @@ from vellum.workflows.types import MergeBehavior
10
10
  from vellum.workflows.types.generics import StateType
11
11
 
12
12
 
13
- class InlinePromptNode(BaseInlinePromptNode[StateType]):
13
+ class InlinePromptNode(BaseInlinePromptNode[StateType], Generic[StateType]):
14
14
  """
15
15
  Used to execute a Prompt defined inline.
16
16
 
@@ -1,4 +1,3 @@
1
- from unittest import mock
2
1
  from uuid import uuid4
3
2
  from typing import Any, Iterator, List
4
3
 
@@ -64,17 +63,8 @@ def test_text_prompt_deployment_node__basic(vellum_client):
64
63
  assert text_output.value == "Hello, world!"
65
64
 
66
65
  # AND we should have made the expected call to stream the prompt execution
67
- vellum_client.execute_prompt_stream.assert_called_once_with(
68
- expand_meta=None,
69
- expand_raw=None,
70
- external_id=None,
71
- inputs=[],
72
- metadata=None,
73
- prompt_deployment_id=None,
74
- prompt_deployment_name="my-deployment",
75
- raw_overrides=None,
76
- release_tag="LATEST",
77
- request_options={
78
- "additional_body_parameters": {"execution_context": {"parent_context": None, "trace_id": mock.ANY}}
79
- },
80
- )
66
+ vellum_client.execute_prompt_stream.assert_called_once()
67
+ _, call_kwargs = vellum_client.execute_prompt_stream.call_args
68
+ exec_ctx = call_kwargs["request_options"]["additional_body_parameters"]["execution_context"]
69
+ assert exec_ctx["parent_context"] is not None
70
+ assert exec_ctx["parent_context"]["type"] == "EXTERNAL"
@@ -12,8 +12,10 @@ from vellum.workflows.inputs.base import BaseInputs
12
12
  from vellum.workflows.nodes.bases import BaseNode
13
13
  from vellum.workflows.nodes.displayable.tool_calling_node.state import ToolCallingState
14
14
  from vellum.workflows.nodes.displayable.tool_calling_node.utils import (
15
+ create_else_node,
15
16
  create_function_node,
16
17
  create_mcp_tool_node,
18
+ create_router_node,
17
19
  create_tool_router_node,
18
20
  get_function_name,
19
21
  get_mcp_tool_name,
@@ -144,6 +146,11 @@ class ToolCallingNode(BaseNode[StateType], Generic[StateType]):
144
146
  max_prompt_iterations=self.max_prompt_iterations,
145
147
  )
146
148
 
149
+ self.router_node = create_router_node(
150
+ functions=self.functions,
151
+ tool_router_node=self.tool_router_node,
152
+ )
153
+
147
154
  self._function_nodes = {}
148
155
  for function in self.functions:
149
156
  if isinstance(function, MCPServer):
@@ -171,7 +178,11 @@ class ToolCallingNode(BaseNode[StateType], Generic[StateType]):
171
178
  edge_graph = router_port >> FunctionNodeClass >> self.tool_router_node
172
179
  graph_set.add(edge_graph)
173
180
 
174
- default_port = getattr(self.tool_router_node.Ports, "default")
181
+ else_node = create_else_node(self.tool_router_node)
182
+ default_port = self.tool_router_node.Ports.default >> {
183
+ else_node.Ports.loop >> self.tool_router_node,
184
+ else_node.Ports.end,
185
+ }
175
186
  graph_set.add(default_port)
176
187
 
177
188
  self._graph = Graph.from_set(graph_set)
@@ -7,3 +7,5 @@ from vellum.workflows.state.base import BaseState
7
7
  class ToolCallingState(BaseState):
8
8
  chat_history: List[ChatMessage] = []
9
9
  prompt_iterations: int = 0
10
+ current_prompt_output_index: int = 0
11
+ current_function_calls_processed: int = 0
@@ -44,6 +44,21 @@ def mock_tool_execution_response():
44
44
  }
45
45
 
46
46
 
47
+ @pytest.fixture
48
+ def mock_tool_execution_error_response():
49
+ """Mock response for failed tool execution API"""
50
+ return {
51
+ "data": {},
52
+ "successful": False,
53
+ "error": (
54
+ 'Request failed error: `{"message":"Not Found",'
55
+ '"documentation_url":"https://docs.github.com/rest/pulls/pulls#get-a-pull-request",'
56
+ '"status":"404"}`'
57
+ ),
58
+ "log_id": "log_raE_fIWNcDPo",
59
+ }
60
+
61
+
47
62
  @pytest.fixture
48
63
  def composio_service():
49
64
  """Create ComposioService with test API key"""
@@ -168,3 +183,37 @@ class TestComposioCoreService:
168
183
  timeout=30,
169
184
  )
170
185
  assert result == {"items": [], "total": 0}
186
+
187
+ def test_execute_tool_failure_surfaces_error(
188
+ self, composio_service, mock_requests, mock_tool_execution_error_response
189
+ ):
190
+ """Test that tool execution failures surface detailed error information"""
191
+ # GIVEN a mock response indicating tool execution failure
192
+ mock_response = Mock()
193
+ mock_response.json.return_value = mock_tool_execution_error_response
194
+ mock_response.raise_for_status.return_value = None
195
+ mock_requests.post.return_value = mock_response
196
+
197
+ # WHEN we execute a tool that fails
198
+ result = composio_service.execute_tool("GITHUB_GET_PR", {"repo": "test", "pr_number": 999})
199
+
200
+ # THEN the result should contain the detailed error message from the API
201
+ assert "Request failed error" in result
202
+ assert "Not Found" in result
203
+
204
+ def test_execute_tool_failure_with_generic_error_message(self, composio_service, mock_requests):
205
+ """Test that tool execution failures with missing error field use generic message"""
206
+ # GIVEN a mock response indicating tool execution failure without error field
207
+ mock_response = Mock()
208
+ mock_response.json.return_value = {
209
+ "data": {},
210
+ "successful": False,
211
+ }
212
+ mock_response.raise_for_status.return_value = None
213
+ mock_requests.post.return_value = mock_response
214
+
215
+ # WHEN we execute a tool that fails
216
+ result = composio_service.execute_tool("TEST_TOOL", {"param": "value"})
217
+
218
+ # THEN the result should contain the generic error message
219
+ assert result == "Tool execution failed"