vellum-ai 1.1.1__py3-none-any.whl → 1.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. vellum/__init__.py +16 -0
  2. vellum/client/README.md +55 -0
  3. vellum/client/__init__.py +66 -507
  4. vellum/client/core/client_wrapper.py +2 -2
  5. vellum/client/core/pydantic_utilities.py +10 -3
  6. vellum/client/raw_client.py +844 -0
  7. vellum/client/reference.md +692 -19
  8. vellum/client/resources/ad_hoc/client.py +23 -180
  9. vellum/client/resources/ad_hoc/raw_client.py +276 -0
  10. vellum/client/resources/container_images/client.py +10 -36
  11. vellum/client/resources/deployments/client.py +16 -62
  12. vellum/client/resources/document_indexes/client.py +16 -72
  13. vellum/client/resources/documents/client.py +8 -30
  14. vellum/client/resources/folder_entities/client.py +4 -8
  15. vellum/client/resources/metric_definitions/client.py +4 -14
  16. vellum/client/resources/ml_models/client.py +2 -8
  17. vellum/client/resources/organizations/client.py +2 -6
  18. vellum/client/resources/prompts/client.py +2 -10
  19. vellum/client/resources/sandboxes/client.py +4 -20
  20. vellum/client/resources/test_suite_runs/client.py +4 -18
  21. vellum/client/resources/test_suites/client.py +11 -86
  22. vellum/client/resources/test_suites/raw_client.py +136 -0
  23. vellum/client/resources/workflow_deployments/client.py +20 -78
  24. vellum/client/resources/workflow_executions/client.py +2 -6
  25. vellum/client/resources/workflow_sandboxes/client.py +2 -10
  26. vellum/client/resources/workflows/client.py +7 -6
  27. vellum/client/resources/workflows/raw_client.py +58 -47
  28. vellum/client/resources/workspace_secrets/client.py +4 -20
  29. vellum/client/resources/workspaces/client.py +2 -6
  30. vellum/client/types/__init__.py +16 -0
  31. vellum/client/types/array_chat_message_content_item.py +4 -2
  32. vellum/client/types/array_chat_message_content_item_request.py +4 -2
  33. vellum/client/types/chat_message_content.py +4 -2
  34. vellum/client/types/chat_message_content_request.py +4 -2
  35. vellum/client/types/node_execution_span.py +2 -0
  36. vellum/client/types/prompt_block.py +4 -2
  37. vellum/client/types/vellum_value.py +4 -2
  38. vellum/client/types/vellum_value_request.py +4 -2
  39. vellum/client/types/vellum_variable_type.py +2 -1
  40. vellum/client/types/vellum_video.py +24 -0
  41. vellum/client/types/vellum_video_request.py +24 -0
  42. vellum/client/types/video_chat_message_content.py +25 -0
  43. vellum/client/types/video_chat_message_content_request.py +25 -0
  44. vellum/client/types/video_prompt_block.py +29 -0
  45. vellum/client/types/video_vellum_value.py +25 -0
  46. vellum/client/types/video_vellum_value_request.py +25 -0
  47. vellum/client/types/workflow_execution_span.py +2 -0
  48. vellum/client/types/workflow_execution_usage_calculation_fulfilled_body.py +22 -0
  49. vellum/prompts/blocks/compilation.py +22 -10
  50. vellum/types/vellum_video.py +3 -0
  51. vellum/types/vellum_video_request.py +3 -0
  52. vellum/types/video_chat_message_content.py +3 -0
  53. vellum/types/video_chat_message_content_request.py +3 -0
  54. vellum/types/video_prompt_block.py +3 -0
  55. vellum/types/video_vellum_value.py +3 -0
  56. vellum/types/video_vellum_value_request.py +3 -0
  57. vellum/types/workflow_execution_usage_calculation_fulfilled_body.py +3 -0
  58. vellum/workflows/events/workflow.py +11 -0
  59. vellum/workflows/graph/graph.py +103 -1
  60. vellum/workflows/graph/tests/test_graph.py +99 -0
  61. vellum/workflows/nodes/bases/base.py +9 -1
  62. vellum/workflows/nodes/displayable/bases/utils.py +4 -2
  63. vellum/workflows/nodes/displayable/tool_calling_node/node.py +19 -18
  64. vellum/workflows/nodes/displayable/tool_calling_node/tests/test_node.py +17 -7
  65. vellum/workflows/nodes/displayable/tool_calling_node/tests/test_utils.py +7 -7
  66. vellum/workflows/nodes/displayable/tool_calling_node/utils.py +47 -80
  67. vellum/workflows/references/environment_variable.py +10 -0
  68. vellum/workflows/runner/runner.py +18 -2
  69. vellum/workflows/state/context.py +101 -12
  70. vellum/workflows/types/definition.py +11 -1
  71. vellum/workflows/types/tests/test_definition.py +19 -0
  72. vellum/workflows/utils/vellum_variables.py +9 -5
  73. vellum/workflows/workflows/base.py +12 -5
  74. {vellum_ai-1.1.1.dist-info → vellum_ai-1.1.3.dist-info}/METADATA +1 -1
  75. {vellum_ai-1.1.1.dist-info → vellum_ai-1.1.3.dist-info}/RECORD +85 -69
  76. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +1 -1
  77. vellum_ee/workflows/display/nodes/vellum/tests/test_code_execution_node.py +55 -1
  78. vellum_ee/workflows/display/nodes/vellum/tests/test_tool_calling_node.py +15 -52
  79. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_mcp_serialization.py +15 -49
  80. vellum_ee/workflows/display/types.py +14 -1
  81. vellum_ee/workflows/display/utils/expressions.py +13 -4
  82. vellum_ee/workflows/display/workflows/base_workflow_display.py +6 -19
  83. {vellum_ai-1.1.1.dist-info → vellum_ai-1.1.3.dist-info}/LICENSE +0 -0
  84. {vellum_ai-1.1.1.dist-info → vellum_ai-1.1.3.dist-info}/WHEEL +0 -0
  85. {vellum_ai-1.1.1.dist-info → vellum_ai-1.1.3.dist-info}/entry_points.txt +0 -0
@@ -12,6 +12,7 @@ from .workflow_sandbox_parent_context import WorkflowSandboxParentContext
12
12
  import typing
13
13
  from .vellum_workflow_execution_event import VellumWorkflowExecutionEvent
14
14
  from .workflow_execution_span_attributes import WorkflowExecutionSpanAttributes
15
+ from .workflow_execution_usage_calculation_fulfilled_body import WorkflowExecutionUsageCalculationFulfilledBody
15
16
  import datetime as dt
16
17
  from ..core.pydantic_utilities import IS_PYDANTIC_V2
17
18
  import pydantic
@@ -21,6 +22,7 @@ class WorkflowExecutionSpan(UniversalBaseModel):
21
22
  name: typing.Literal["workflow.execution"] = "workflow.execution"
22
23
  events: typing.List[VellumWorkflowExecutionEvent]
23
24
  attributes: WorkflowExecutionSpanAttributes
25
+ usage_result: typing.Optional[WorkflowExecutionUsageCalculationFulfilledBody] = None
24
26
  span_id: str
25
27
  start_ts: dt.datetime
26
28
  end_ts: dt.datetime
@@ -0,0 +1,22 @@
1
+ # This file was auto-generated by Fern from our API Definition.
2
+
3
+ from ..core.pydantic_utilities import UniversalBaseModel
4
+ import typing
5
+ from .ml_model_usage_wrapper import MlModelUsageWrapper
6
+ from .price import Price
7
+ from ..core.pydantic_utilities import IS_PYDANTIC_V2
8
+ import pydantic
9
+
10
+
11
# Fern-generated model (per this file's header): regenerating from the API definition
# overwrites hand edits, including these comments.
# Judging by its name, this is the payload of a fulfilled usage calculation for a
# workflow execution: per-model usage records plus cost entries. WorkflowExecutionSpan
# carries it as its optional `usage_result` field.
class WorkflowExecutionUsageCalculationFulfilledBody(UniversalBaseModel):
    # Usage records; the element schema is defined by MlModelUsageWrapper.
    usage: typing.List[MlModelUsageWrapper]
    # Cost entries; the element schema is defined by Price.
    cost: typing.List[Price]

    # Pydantic v1 and v2 take model configuration differently. Both branches make
    # instances immutable (frozen) and accept fields not declared above (extra=allow).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
    else:

        class Config:
            frozen = True
            smart_union = True  # set only on the Pydantic v1 branch
            extra = pydantic.Extra.allow
@@ -142,29 +142,41 @@ def compile_prompt_blocks(
142
142
  )
143
143
  compiled_blocks.append(function_call_block)
144
144
 
145
- elif block.block_type == "IMAGE":
146
- image_block = CompiledValuePromptBlock(
147
- content=ImageVellumValue(
148
- value=VellumImage(
145
+ elif block.block_type == "AUDIO":
146
+ audio_block = CompiledValuePromptBlock(
147
+ content=AudioVellumValue(
148
+ value=VellumAudio(
149
149
  src=block.src,
150
150
  metadata=block.metadata,
151
151
  ),
152
152
  ),
153
153
  cache_config=block.cache_config,
154
154
  )
155
- compiled_blocks.append(image_block)
155
+ compiled_blocks.append(audio_block)
156
156
 
157
- elif block.block_type == "AUDIO":
158
- audio_block = CompiledValuePromptBlock(
159
- content=AudioVellumValue(
160
- value=VellumAudio(
157
+ # elif block.block_type == "VIDEO":
158
+ # video_block = CompiledValuePromptBlock(
159
+ # content=VideoVellumValue(
160
+ # value=VellumVideo(
161
+ # src=block.src,
162
+ # metadata=block.metadata,
163
+ # ),
164
+ # ),
165
+ # cache_config=block.cache_config,
166
+ # )
167
+ # compiled_blocks.append(video_block)
168
+
169
+ elif block.block_type == "IMAGE":
170
+ image_block = CompiledValuePromptBlock(
171
+ content=ImageVellumValue(
172
+ value=VellumImage(
161
173
  src=block.src,
162
174
  metadata=block.metadata,
163
175
  ),
164
176
  ),
165
177
  cache_config=block.cache_config,
166
178
  )
167
- compiled_blocks.append(audio_block)
179
+ compiled_blocks.append(image_block)
168
180
 
169
181
  elif block.block_type == "DOCUMENT":
170
182
  document_block = CompiledValuePromptBlock(
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.vellum_video import *
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.vellum_video_request import *
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.video_chat_message_content import *
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.video_chat_message_content_request import *
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.video_prompt_block import *
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.video_vellum_value import *
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.video_vellum_value_request import *
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.workflow_execution_usage_calculation_fulfilled_body import *
@@ -41,6 +41,17 @@ class _BaseWorkflowEvent(BaseEvent):
41
41
  def workflow_definition(self) -> Type["BaseWorkflow"]:
42
42
  return self.body.workflow_definition
43
43
 
44
@property
def monitoring_url(self) -> Optional[str]:
    """
    URL for viewing this workflow execution in the Vellum UI, when one can be built.

    This base implementation has no monitoring information available and always
    returns None. Presumably event types that carry execution context override it
    (not visible here). None is also the intended result when monitoring is disabled
    or no context is available.
    """
    return None
+
44
55
 
45
56
  class NodeEventDisplayContext(UniversalBaseModel):
46
57
  input_display: Dict[str, UUID]
@@ -1,4 +1,4 @@
1
- from typing import TYPE_CHECKING, Iterator, List, Set, Type, Union
1
+ from typing import TYPE_CHECKING, Iterator, List, Optional, Set, Type, Union
2
2
 
3
3
  from orderly_set import OrderedSet
4
4
 
@@ -154,3 +154,105 @@ class Graph:
154
154
  for edge in edges:
155
155
  if edge not in self._edges:
156
156
  self._edges.append(edge)
157
+
158
def __str__(self) -> str:
    """
    Return a visual ASCII representation of the graph showing the flow structure.

    An empty graph renders as ``Graph(empty)``. A graph with entrypoints but no edges
    lists its entrypoint port(s) on one line. Otherwise a tree-style flow diagram is
    built by ``_build_flow_diagram``.
    """
    if not self._edges and not self._entrypoints:
        return "Graph(empty)"

    if not self._edges:
        if len(self._entrypoints) == 1:
            port = next(iter(self._entrypoints))
            return f"Graph: {self._get_port_name(port)}"
        port_names = [self._get_port_name(port) for port in self._entrypoints]
        return f"Graph: [{', '.join(port_names)}]"

    return self._build_flow_diagram()


def _build_flow_diagram(self) -> str:
    """
    Build a connected, tree-style flow diagram of the graph's edges.

    Nodes are tracked by class rather than by ``__name__``, so two distinct node
    classes that share a name are not merged.

    Root selection:
    - Rendering starts from entrypoint nodes that no edge points back to.
    - If every entrypoint is re-entered by some edge (a loop), the first entrypoint
      node is used instead.
    - Any entrypoint or edge source still not reachable from the chosen roots becomes
      an additional root, so no edge is silently left out of the diagram.

    Markers:
    - A node already on the current path is rendered with ``⟲ (loops back)``.
    - A node already rendered elsewhere is rendered with ``→ (see above)`` and is
      not expanded again.
    """
    lines = ["Graph:"]

    # Outgoing targets per source node, in edge order (dicts preserve insertion order).
    adjacency: dict = {}
    for edge in self._edges:
        adjacency.setdefault(edge.from_port.node_class, []).append(edge.to_node)

    target_nodes = {child for children in adjacency.values() for child in children}

    # Distinct entrypoint nodes in entrypoint order; several ports of one node collapse.
    entry_nodes = list(dict.fromkeys(port.node_class for port in self._entrypoints))

    def reachable_from(start) -> set:
        """Return every node reachable from ``start``, including ``start`` itself."""
        seen = {start}
        stack = [start]
        while stack:
            for child in adjacency.get(stack.pop(), ()):
                if child not in seen:
                    seen.add(child)
                    stack.append(child)
        return seen

    # Root preference order:
    #   1. entrypoints nothing points back to,
    #   2. any entrypoint (covers loops back into the entry),
    #   3. remaining edge sources (pieces disconnected from the entrypoints).
    candidates = [node for node in entry_nodes if node not in target_nodes]
    candidates += entry_nodes
    candidates += list(adjacency)

    root_nodes = []
    reached: set = set()
    for candidate in candidates:
        if candidate not in reached:
            root_nodes.append(candidate)
            reached |= reachable_from(candidate)

    visited: set = set()
    currently_visiting: set = set()

    def render_node(node, prefix: str, is_last: bool) -> None:
        connector = "└─" if is_last else "├─"
        name = self._get_node_name(node)

        # Already on the current path: a cycle, so stop instead of recursing forever.
        if node in currently_visiting:
            lines.append(f"{prefix}{connector} {name} ⟲ (loops back)")
            return

        # Already drawn on another branch: reference it rather than repeating its subtree.
        if node in visited:
            lines.append(f"{prefix}{connector} {name} → (see above)")
            return

        visited.add(node)
        currently_visiting.add(node)
        lines.append(f"{prefix}{connector} {name}")

        children = adjacency.get(node, [])
        next_prefix = prefix + (" " if is_last else "│ ")
        for i, child in enumerate(children):
            render_node(child, next_prefix, i == len(children) - 1)

        currently_visiting.remove(node)

    for i, root in enumerate(root_nodes):
        render_node(root, " ", i == len(root_nodes) - 1)

    return "\n".join(lines)


def _get_port_name(self, port: "Port") -> str:
    """Get a readable ``Node.port`` name for a port, falling back to ``str(port)``."""
    try:
        if hasattr(port, "node_class") and hasattr(port.node_class, "__name__"):
            node_name = port.node_class.__name__
            port_name = getattr(port, "name", "unknown")
            return f"{node_name}.{port_name}"
        else:
            return str(port)
    except Exception:
        # Best-effort for string rendering: never let __str__ raise.
        return f"<Port:{getattr(port, 'name', 'unknown')}>"


def _get_node_name(self, node: Type["BaseNode"]) -> str:
    """Get a readable name for a node class, falling back to ``str(node)``."""
    try:
        return getattr(node, "__name__", str(node))
    except Exception:
        # Best-effort for string rendering: never let __str__ raise.
        return "<Node:unknown>"
@@ -484,3 +484,102 @@ def test_graph__set_to_graph():
484
484
 
485
485
  # AND two edges
486
486
  assert len(list(graph.edges)) == 2
487
+
488
+
489
def test_graph__str_simple_linear():
    # GIVEN three nodes chained one after another: A -> B -> C
    class NodeA(BaseNode):
        pass

    class NodeB(BaseNode):
        pass

    class NodeC(BaseNode):
        pass

    linear_graph = NodeA >> NodeB >> NodeC

    # WHEN the graph is rendered as text
    rendered = str(linear_graph)

    # THEN each node appears nested beneath its predecessor
    assert rendered.split("\n") == ["Graph:", " └─ NodeA", " └─ NodeB", " └─ NodeC"]
508
+
509
+
510
def test_graph__str_with_branching():
    # GIVEN a node that fans out to two successors: A -> {B, C}
    class NodeA(BaseNode):
        pass

    class NodeB(BaseNode):
        pass

    class NodeC(BaseNode):
        pass

    fan_out_graph = NodeA >> {NodeB, NodeC}

    # WHEN the graph is rendered as text
    output_lines = str(fan_out_graph).split("\n")

    # THEN the header and the root node come first
    assert output_lines[:2] == ["Graph:", " └─ NodeA"]

    # AND exactly two branch lines follow, one per successor. Set iteration order is
    # not guaranteed, so only membership and connector shape are checked.
    stripped_lines = (line.strip() for line in output_lines[2:])
    branches = [branch for branch in stripped_lines if branch]
    assert len(branches) == 2
    assert any("NodeB" in branch for branch in branches)
    assert any("NodeC" in branch for branch in branches)
    assert all(branch.startswith(("├─ ", "└─ ")) for branch in branches)
538
+
539
+
540
def test_graph__str_with_loop():
    # GIVEN two nodes wired into a cycle: A -> B -> A
    class NodeA(BaseNode):
        pass

    class NodeB(BaseNode):
        pass

    # The cycle is built from explicit edges rather than the >> operator
    forward = Edge(NodeA.Ports.default, NodeB)
    backward = Edge(NodeB.Ports.default, NodeA)
    cyclic_graph = Graph(
        entrypoints={NodeA.Ports.default},
        edges=[forward, backward],
        terminals={NodeA.Ports.default},
    )

    # WHEN the graph is rendered as text
    rendered = str(cyclic_graph)

    # THEN the revisited node is flagged as a loop instead of being expanded again
    assert rendered == "\n".join(["Graph:", " └─ NodeA", " └─ NodeB", " └─ NodeA ⟲ (loops back)"])
560
+
561
+
562
def test_graph__str_empty_graph():
    # GIVEN a graph with no entrypoints, edges, or terminals
    empty_graph = Graph(entrypoints=set(), edges=[], terminals=set())

    # WHEN it is rendered as text, THEN the dedicated empty marker is produced
    assert str(empty_graph) == "Graph(empty)"
571
+
572
+
573
def test_graph__str_single_node():
    # GIVEN a graph built from a single node: it has entrypoint(s) but no edges
    class SingleNode(BaseNode):
        pass

    graph = Graph.from_node(SingleNode)

    # WHEN we convert the graph to string
    result = str(graph)

    # THEN the lone entrypoint port is rendered on the header line. The output is
    # deterministic, so pin the exact string rather than only checking substrings.
    assert result == "Graph: SingleNode.default"
@@ -8,6 +8,7 @@ from typing import Any, Dict, Generic, Iterator, Optional, Set, Tuple, Type, Typ
8
8
 
9
9
  from vellum.workflows.constants import undefined
10
10
  from vellum.workflows.descriptors.base import BaseDescriptor
11
+ from vellum.workflows.descriptors.exceptions import InvalidExpressionException
11
12
  from vellum.workflows.descriptors.utils import is_unresolved, resolve_value
12
13
  from vellum.workflows.errors.types import WorkflowErrorCode
13
14
  from vellum.workflows.events.node import NodeExecutionStreamingEvent
@@ -302,7 +303,14 @@ class BaseNode(Generic[StateType], ABC, metaclass=BaseNodeMeta):
302
303
  if not descriptor.instance:
303
304
  continue
304
305
 
305
- resolved_value = resolve_value(descriptor.instance, state, path=descriptor.name)
306
+ try:
307
+ resolved_value = resolve_value(descriptor.instance, state, path=descriptor.name)
308
+ except InvalidExpressionException as e:
309
+ raise NodeException(
310
+ message=str(e),
311
+ code=WorkflowErrorCode.INVALID_INPUTS,
312
+ ) from e
313
+
306
314
  if is_unresolved(resolved_value):
307
315
  return False
308
316
 
@@ -35,8 +35,9 @@ VELLUM_VALUE_REQUEST_TUPLE = (
35
35
  StringVellumValueRequest,
36
36
  NumberVellumValueRequest,
37
37
  JsonVellumValueRequest,
38
- ImageVellumValueRequest,
39
38
  AudioVellumValueRequest,
39
+ # VideoVellumValueRequest,
40
+ ImageVellumValueRequest,
40
41
  FunctionCallVellumValueRequest,
41
42
  ErrorVellumValueRequest,
42
43
  ArrayVellumValueRequest,
@@ -78,8 +79,9 @@ def primitive_to_vellum_value(value: Any) -> VellumValue:
78
79
  StringVellumValue,
79
80
  NumberVellumValue,
80
81
  JsonVellumValue,
81
- ImageVellumValue,
82
82
  AudioVellumValue,
83
+ # VideoVellumValue,
84
+ ImageVellumValue,
83
85
  FunctionCallVellumValue,
84
86
  ErrorVellumValue,
85
87
  ArrayVellumValue,
@@ -16,7 +16,7 @@ from vellum.workflows.nodes.displayable.tool_calling_node.utils import (
16
16
  create_function_node,
17
17
  create_mcp_tool_node,
18
18
  create_router_node,
19
- create_tool_router_node,
19
+ create_tool_prompt_node,
20
20
  get_function_name,
21
21
  get_mcp_tool_name,
22
22
  hydrate_mcp_tool_definitions,
@@ -78,7 +78,7 @@ class ToolCallingNode(BaseNode[StateType], Generic[StateType]):
78
78
  graph = self._graph
79
79
 
80
80
  class Outputs(BaseWorkflow.Outputs):
81
- text: str = self.tool_router_node.Outputs.text
81
+ text: str = self.tool_prompt_node.Outputs.text
82
82
  chat_history: List[ChatMessage] = ToolCallingState.chat_history
83
83
 
84
84
  subworkflow = ToolCallingWorkflow(
@@ -137,7 +137,7 @@ class ToolCallingNode(BaseNode[StateType], Generic[StateType]):
137
137
  )
138
138
 
139
139
  def _build_graph(self) -> None:
140
- self.tool_router_node = create_tool_router_node(
140
+ self.tool_prompt_node = create_tool_prompt_node(
141
141
  ml_model=self.ml_model,
142
142
  blocks=self.blocks,
143
143
  functions=self.functions,
@@ -146,9 +146,10 @@ class ToolCallingNode(BaseNode[StateType], Generic[StateType]):
146
146
  max_prompt_iterations=self.max_prompt_iterations,
147
147
  )
148
148
 
149
+ # Create the router node (handles routing logic only)
149
150
  self.router_node = create_router_node(
150
151
  functions=self.functions,
151
- tool_router_node=self.tool_router_node,
152
+ tool_prompt_node=self.tool_prompt_node,
152
153
  )
153
154
 
154
155
  self._function_nodes = {}
@@ -160,29 +161,29 @@ class ToolCallingNode(BaseNode[StateType], Generic[StateType]):
160
161
 
161
162
  self._function_nodes[function_name] = create_mcp_tool_node(
162
163
  tool_def=tool_definition,
163
- tool_router_node=self.tool_router_node,
164
+ tool_prompt_node=self.tool_prompt_node,
164
165
  )
165
166
  else:
166
167
  function_name = get_function_name(function)
167
168
 
168
169
  self._function_nodes[function_name] = create_function_node(
169
170
  function=function,
170
- tool_router_node=self.tool_router_node,
171
+ tool_prompt_node=self.tool_prompt_node,
171
172
  )
172
173
 
173
- graph_set = set()
174
+ graph: Graph = self.tool_prompt_node >> self.router_node
174
175
 
175
- # Add connections from ports of router to function nodes and back to router
176
176
  for function_name, FunctionNodeClass in self._function_nodes.items():
177
- router_port = getattr(self.tool_router_node.Ports, function_name)
178
- edge_graph = router_port >> FunctionNodeClass >> self.tool_router_node
179
- graph_set.add(edge_graph)
180
-
181
- else_node = create_else_node(self.tool_router_node)
182
- default_port = self.tool_router_node.Ports.default >> {
183
- else_node.Ports.loop >> self.tool_router_node,
184
- else_node.Ports.end,
177
+ router_port = getattr(self.router_node.Ports, function_name)
178
+ function_subgraph = router_port >> FunctionNodeClass >> self.router_node
179
+ graph._extend_edges(function_subgraph.edges)
180
+
181
+ else_node = create_else_node(self.tool_prompt_node)
182
+ default_port_graph = self.router_node.Ports.default >> {
183
+ else_node.Ports.loop_to_router >> self.router_node, # More outputs to process
184
+ else_node.Ports.loop_to_prompt >> self.tool_prompt_node, # Need new prompt iteration
185
+ else_node.Ports.end, # Finished
185
186
  }
186
- graph_set.add(default_port)
187
+ graph._extend_edges(default_port_graph.edges)
187
188
 
188
- self._graph = Graph.from_set(graph_set)
189
+ self._graph = graph
@@ -16,7 +16,11 @@ from vellum.workflows.inputs.base import BaseInputs
16
16
  from vellum.workflows.nodes.bases import BaseNode
17
17
  from vellum.workflows.nodes.displayable.tool_calling_node.node import ToolCallingNode
18
18
  from vellum.workflows.nodes.displayable.tool_calling_node.state import ToolCallingState
19
- from vellum.workflows.nodes.displayable.tool_calling_node.utils import create_function_node, create_tool_router_node
19
+ from vellum.workflows.nodes.displayable.tool_calling_node.utils import (
20
+ create_function_node,
21
+ create_router_node,
22
+ create_tool_prompt_node,
23
+ )
20
24
  from vellum.workflows.outputs.base import BaseOutputs
21
25
  from vellum.workflows.state.base import BaseState, StateMeta
22
26
  from vellum.workflows.state.context import WorkflowContext
@@ -35,8 +39,8 @@ def test_port_condition_match_function_name():
35
39
  """
36
40
  Test that the port condition correctly matches the function name.
37
41
  """
38
- # GIVEN a tool router node
39
- router_node = create_tool_router_node(
42
+ # GIVEN a tool prompt node
43
+ tool_prompt_node = create_tool_prompt_node(
40
44
  ml_model="test-model",
41
45
  blocks=[],
42
46
  functions=[first_function, second_function],
@@ -44,11 +48,17 @@ def test_port_condition_match_function_name():
44
48
  parameters=DEFAULT_PROMPT_PARAMETERS,
45
49
  )
46
50
 
51
+ # AND a router node that references the tool prompt node
52
+ router_node = create_router_node(
53
+ functions=[first_function, second_function],
54
+ tool_prompt_node=tool_prompt_node,
55
+ )
56
+
47
57
  # AND a state with a function call to the first function
48
58
  state = ToolCallingState(
49
59
  meta=StateMeta(
50
60
  node_outputs={
51
- router_node.Outputs.results: [
61
+ tool_prompt_node.Outputs.results: [
52
62
  FunctionCallVellumValue(
53
63
  value=FunctionCall(
54
64
  arguments={}, id="call_zp7pBQjGAOBCr7lo0AbR1HXT", name="first_function", state="FULFILLED"
@@ -93,8 +103,8 @@ def test_tool_calling_node_inline_workflow_context():
93
103
  class Outputs(BaseOutputs):
94
104
  generated_files = MyNode.Outputs.generated_files
95
105
 
96
- # GIVEN a tool router node
97
- tool_router_node = create_tool_router_node(
106
+ # GIVEN a tool prompt node
107
+ tool_prompt_node = create_tool_prompt_node(
98
108
  ml_model="test-model",
99
109
  blocks=[],
100
110
  functions=[MyWorkflow],
@@ -105,7 +115,7 @@ def test_tool_calling_node_inline_workflow_context():
105
115
  # WHEN we create a function node for the workflow
106
116
  function_node_class = create_function_node(
107
117
  function=MyWorkflow,
108
- tool_router_node=tool_router_node,
118
+ tool_prompt_node=tool_prompt_node,
109
119
  )
110
120
 
111
121
  # AND we create an instance with a context containing generated_files
@@ -13,7 +13,7 @@ from vellum.workflows import BaseWorkflow
13
13
  from vellum.workflows.inputs.base import BaseInputs
14
14
  from vellum.workflows.nodes.bases import BaseNode
15
15
  from vellum.workflows.nodes.displayable.tool_calling_node.utils import (
16
- create_tool_router_node,
16
+ create_tool_prompt_node,
17
17
  get_function_name,
18
18
  get_mcp_tool_name,
19
19
  )
@@ -104,9 +104,9 @@ def test_get_function_name_composio_tool_definition_various_toolkits(
104
104
  assert result == expected_result
105
105
 
106
106
 
107
- def test_create_tool_router_node_max_prompt_iterations(vellum_adhoc_prompt_client):
107
+ def test_create_tool_prompt_node_max_prompt_iterations(vellum_adhoc_prompt_client):
108
108
  # GIVEN a tool router node with max_prompt_iterations set to None
109
- tool_router_node = create_tool_router_node(
109
+ tool_prompt_node = create_tool_prompt_node(
110
110
  ml_model="gpt-4o-mini",
111
111
  blocks=[],
112
112
  functions=[],
@@ -129,7 +129,7 @@ def test_create_tool_router_node_max_prompt_iterations(vellum_adhoc_prompt_clien
129
129
  vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.side_effect = generate_prompt_events
130
130
 
131
131
  # WHEN we run the tool router node
132
- node_instance = tool_router_node()
132
+ node_instance = tool_prompt_node()
133
133
  outputs = list(node_instance.run())
134
134
  assert outputs[0].name == "results"
135
135
  assert outputs[0].value == [StringVellumValue(type="STRING", value="test output")]
@@ -137,7 +137,7 @@ def test_create_tool_router_node_max_prompt_iterations(vellum_adhoc_prompt_clien
137
137
  assert outputs[1].value == "test output"
138
138
 
139
139
 
140
- def test_create_tool_router_node_chat_history_block_dict(vellum_adhoc_prompt_client):
140
+ def test_create_tool_prompt_node_chat_history_block_dict(vellum_adhoc_prompt_client):
141
141
  # GIVEN a list of blocks with a chat history block
142
142
  blocks = [
143
143
  {
@@ -165,7 +165,7 @@ def test_create_tool_router_node_chat_history_block_dict(vellum_adhoc_prompt_cli
165
165
  },
166
166
  ]
167
167
 
168
- tool_router_node = create_tool_router_node(
168
+ tool_prompt_node = create_tool_prompt_node(
169
169
  ml_model="gpt-4o-mini",
170
170
  blocks=blocks, # type: ignore
171
171
  functions=[],
@@ -187,7 +187,7 @@ def test_create_tool_router_node_chat_history_block_dict(vellum_adhoc_prompt_cli
187
187
  vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.side_effect = generate_prompt_events
188
188
 
189
189
  # WHEN we run the tool router node
190
- node_instance = tool_router_node()
190
+ node_instance = tool_prompt_node()
191
191
  list(node_instance.run())
192
192
 
193
193
  # THEN the API was called with compiled blocks