vellum-ai: vellum_ai-0.14.49-py3-none-any.whl → vellum_ai-0.14.51-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35):
  1. vellum/__init__.py +6 -2
  2. vellum/client/core/client_wrapper.py +1 -1
  3. vellum/client/types/__init__.py +6 -2
  4. vellum/client/types/deployment_read.py +1 -1
  5. vellum/client/types/slim_workflow_execution_read.py +2 -2
  6. vellum/client/types/workflow_event_execution_read.py +2 -2
  7. vellum/client/types/{workflow_execution_usage_calculation_fulfilled_body.py → workflow_execution_usage_calculation_error.py} +5 -6
  8. vellum/client/types/workflow_execution_usage_calculation_error_code_enum.py +7 -0
  9. vellum/client/types/workflow_execution_usage_result.py +24 -0
  10. vellum/types/{workflow_execution_usage_calculation_fulfilled_body.py → workflow_execution_usage_calculation_error.py} +1 -1
  11. vellum/types/workflow_execution_usage_calculation_error_code_enum.py +3 -0
  12. vellum/types/workflow_execution_usage_result.py +3 -0
  13. vellum/workflows/nodes/core/map_node/node.py +74 -87
  14. vellum/workflows/nodes/core/map_node/tests/test_node.py +49 -0
  15. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +1 -2
  16. vellum/workflows/nodes/displayable/bases/inline_prompt_node/tests/test_inline_prompt_node.py +3 -3
  17. vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py +1 -1
  18. vellum/workflows/nodes/experimental/__init__.py +3 -0
  19. vellum/workflows/nodes/experimental/tool_calling_node/tests/test_tool_calling_node.py +53 -0
  20. vellum/workflows/nodes/experimental/tool_calling_node/utils.py +12 -5
  21. vellum/workflows/state/encoder.py +4 -0
  22. vellum/workflows/workflows/base.py +8 -0
  23. {vellum_ai-0.14.49.dist-info → vellum_ai-0.14.51.dist-info}/METADATA +1 -1
  24. {vellum_ai-0.14.49.dist-info → vellum_ai-0.14.51.dist-info}/RECORD +35 -29
  25. vellum_ee/workflows/display/nodes/base_node_display.py +31 -2
  26. vellum_ee/workflows/display/nodes/get_node_display_class.py +1 -24
  27. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +29 -12
  28. vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_node.py +33 -1
  29. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_serialization.py +177 -0
  30. vellum_ee/workflows/display/utils/expressions.py +1 -1
  31. vellum_ee/workflows/display/workflows/base_workflow_display.py +3 -24
  32. vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +3 -3
  33. {vellum_ai-0.14.49.dist-info → vellum_ai-0.14.51.dist-info}/LICENSE +0 -0
  34. {vellum_ai-0.14.49.dist-info → vellum_ai-0.14.51.dist-info}/WHEEL +0 -0
  35. {vellum_ai-0.14.49.dist-info → vellum_ai-0.14.51.dist-info}/entry_points.txt +0 -0
vellum/__init__.py CHANGED
@@ -544,7 +544,9 @@ from .types import (
544
544
  WorkflowExecutionSpanAttributes,
545
545
  WorkflowExecutionStreamingBody,
546
546
  WorkflowExecutionStreamingEvent,
547
- WorkflowExecutionUsageCalculationFulfilledBody,
547
+ WorkflowExecutionUsageCalculationError,
548
+ WorkflowExecutionUsageCalculationErrorCodeEnum,
549
+ WorkflowExecutionUsageResult,
548
550
  WorkflowExecutionViewOnlineEvalMetricResult,
549
551
  WorkflowExecutionWorkflowResultEvent,
550
552
  WorkflowExpandMetaRequest,
@@ -1177,7 +1179,9 @@ __all__ = [
1177
1179
  "WorkflowExecutionSpanAttributes",
1178
1180
  "WorkflowExecutionStreamingBody",
1179
1181
  "WorkflowExecutionStreamingEvent",
1180
- "WorkflowExecutionUsageCalculationFulfilledBody",
1182
+ "WorkflowExecutionUsageCalculationError",
1183
+ "WorkflowExecutionUsageCalculationErrorCodeEnum",
1184
+ "WorkflowExecutionUsageResult",
1181
1185
  "WorkflowExecutionViewOnlineEvalMetricResult",
1182
1186
  "WorkflowExecutionWorkflowResultEvent",
1183
1187
  "WorkflowExpandMetaRequest",
@@ -18,7 +18,7 @@ class BaseClientWrapper:
18
18
  headers: typing.Dict[str, str] = {
19
19
  "X-Fern-Language": "Python",
20
20
  "X-Fern-SDK-Name": "vellum-ai",
21
- "X-Fern-SDK-Version": "0.14.49",
21
+ "X-Fern-SDK-Version": "0.14.51",
22
22
  }
23
23
  headers["X-API-KEY"] = self.api_key
24
24
  return headers
@@ -568,7 +568,9 @@ from .workflow_execution_span import WorkflowExecutionSpan
568
568
  from .workflow_execution_span_attributes import WorkflowExecutionSpanAttributes
569
569
  from .workflow_execution_streaming_body import WorkflowExecutionStreamingBody
570
570
  from .workflow_execution_streaming_event import WorkflowExecutionStreamingEvent
571
- from .workflow_execution_usage_calculation_fulfilled_body import WorkflowExecutionUsageCalculationFulfilledBody
571
+ from .workflow_execution_usage_calculation_error import WorkflowExecutionUsageCalculationError
572
+ from .workflow_execution_usage_calculation_error_code_enum import WorkflowExecutionUsageCalculationErrorCodeEnum
573
+ from .workflow_execution_usage_result import WorkflowExecutionUsageResult
572
574
  from .workflow_execution_view_online_eval_metric_result import WorkflowExecutionViewOnlineEvalMetricResult
573
575
  from .workflow_execution_workflow_result_event import WorkflowExecutionWorkflowResultEvent
574
576
  from .workflow_expand_meta_request import WorkflowExpandMetaRequest
@@ -1154,7 +1156,9 @@ __all__ = [
1154
1156
  "WorkflowExecutionSpanAttributes",
1155
1157
  "WorkflowExecutionStreamingBody",
1156
1158
  "WorkflowExecutionStreamingEvent",
1157
- "WorkflowExecutionUsageCalculationFulfilledBody",
1159
+ "WorkflowExecutionUsageCalculationError",
1160
+ "WorkflowExecutionUsageCalculationErrorCodeEnum",
1161
+ "WorkflowExecutionUsageResult",
1158
1162
  "WorkflowExecutionViewOnlineEvalMetricResult",
1159
1163
  "WorkflowExecutionWorkflowResultEvent",
1160
1164
  "WorkflowExpandMetaRequest",
@@ -50,7 +50,7 @@ class DeploymentRead(UniversalBaseModel):
50
50
 
51
51
  active_model_version_ids: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
52
52
  """
53
- Deprecated. This now always returns a null value.
53
+ Deprecated. This now always returns an empty array.
54
54
  """
55
55
 
56
56
  last_deployed_history_item_id: str = pydantic.Field()
@@ -15,7 +15,7 @@ from .execution_vellum_value import ExecutionVellumValue
15
15
  from .workflow_error import WorkflowError
16
16
  from .workflow_execution_actual import WorkflowExecutionActual
17
17
  from .workflow_execution_view_online_eval_metric_result import WorkflowExecutionViewOnlineEvalMetricResult
18
- from .workflow_execution_usage_calculation_fulfilled_body import WorkflowExecutionUsageCalculationFulfilledBody
18
+ from .workflow_execution_usage_result import WorkflowExecutionUsageResult
19
19
  from ..core.pydantic_utilities import IS_PYDANTIC_V2
20
20
  import pydantic
21
21
 
@@ -30,7 +30,7 @@ class SlimWorkflowExecutionRead(UniversalBaseModel):
30
30
  error: typing.Optional[WorkflowError] = None
31
31
  latest_actual: typing.Optional[WorkflowExecutionActual] = None
32
32
  metric_results: typing.List[WorkflowExecutionViewOnlineEvalMetricResult]
33
- usage_results: typing.List[WorkflowExecutionUsageCalculationFulfilledBody]
33
+ usage_results: typing.Optional[typing.List[WorkflowExecutionUsageResult]] = None
34
34
 
35
35
  if IS_PYDANTIC_V2:
36
36
  model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
@@ -15,7 +15,7 @@ from .execution_vellum_value import ExecutionVellumValue
15
15
  from .workflow_error import WorkflowError
16
16
  from .workflow_execution_actual import WorkflowExecutionActual
17
17
  from .workflow_execution_view_online_eval_metric_result import WorkflowExecutionViewOnlineEvalMetricResult
18
- from .workflow_execution_usage_calculation_fulfilled_body import WorkflowExecutionUsageCalculationFulfilledBody
18
+ from .workflow_execution_usage_result import WorkflowExecutionUsageResult
19
19
  from .vellum_span import VellumSpan
20
20
  from ..core.pydantic_utilities import IS_PYDANTIC_V2
21
21
  import pydantic
@@ -31,7 +31,7 @@ class WorkflowEventExecutionRead(UniversalBaseModel):
31
31
  error: typing.Optional[WorkflowError] = None
32
32
  latest_actual: typing.Optional[WorkflowExecutionActual] = None
33
33
  metric_results: typing.List[WorkflowExecutionViewOnlineEvalMetricResult]
34
- usage_results: typing.List[WorkflowExecutionUsageCalculationFulfilledBody]
34
+ usage_results: typing.Optional[typing.List[WorkflowExecutionUsageResult]] = None
35
35
  spans: typing.List[VellumSpan]
36
36
 
37
37
  if IS_PYDANTIC_V2:
@@ -1,16 +1,15 @@
1
1
  # This file was auto-generated by Fern from our API Definition.
2
2
 
3
3
  from ..core.pydantic_utilities import UniversalBaseModel
4
- import typing
5
- from .ml_model_usage_wrapper import MlModelUsageWrapper
6
- from .price import Price
4
+ from .workflow_execution_usage_calculation_error_code_enum import WorkflowExecutionUsageCalculationErrorCodeEnum
7
5
  from ..core.pydantic_utilities import IS_PYDANTIC_V2
6
+ import typing
8
7
  import pydantic
9
8
 
10
9
 
11
- class WorkflowExecutionUsageCalculationFulfilledBody(UniversalBaseModel):
12
- usage: typing.List[MlModelUsageWrapper]
13
- cost: typing.List[Price]
10
+ class WorkflowExecutionUsageCalculationError(UniversalBaseModel):
11
+ code: WorkflowExecutionUsageCalculationErrorCodeEnum
12
+ message: str
14
13
 
15
14
  if IS_PYDANTIC_V2:
16
15
  model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
@@ -0,0 +1,7 @@
1
+ # This file was auto-generated by Fern from our API Definition.
2
+
3
+ import typing
4
+
5
+ WorkflowExecutionUsageCalculationErrorCodeEnum = typing.Union[
6
+ typing.Literal["UNKNOWN", "DEPENDENCIES_FAILED", "NO_USAGE_CALCULATED", "INTERNAL_SERVER_ERROR"], typing.Any
7
+ ]
@@ -0,0 +1,24 @@
1
+ # This file was auto-generated by Fern from our API Definition.
2
+
3
+ from ..core.pydantic_utilities import UniversalBaseModel
4
+ import typing
5
+ from .ml_model_usage_wrapper import MlModelUsageWrapper
6
+ from .price import Price
7
+ from .workflow_execution_usage_calculation_error import WorkflowExecutionUsageCalculationError
8
+ from ..core.pydantic_utilities import IS_PYDANTIC_V2
9
+ import pydantic
10
+
11
+
12
+ class WorkflowExecutionUsageResult(UniversalBaseModel):
13
+ usage: typing.Optional[typing.List[MlModelUsageWrapper]] = None
14
+ cost: typing.Optional[typing.List[Price]] = None
15
+ error: typing.Optional[WorkflowExecutionUsageCalculationError] = None
16
+
17
+ if IS_PYDANTIC_V2:
18
+ model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
19
+ else:
20
+
21
+ class Config:
22
+ frozen = True
23
+ smart_union = True
24
+ extra = pydantic.Extra.allow
@@ -1,3 +1,3 @@
1
1
  # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
2
 
3
- from vellum.client.types.workflow_execution_usage_calculation_fulfilled_body import *
3
+ from vellum.client.types.workflow_execution_usage_calculation_error import *
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.workflow_execution_usage_calculation_error_code_enum import *
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.workflow_execution_usage_result import *
@@ -1,7 +1,7 @@
1
1
  from collections import defaultdict
2
+ import concurrent.futures
2
3
  import logging
3
4
  from queue import Empty, Queue
4
- from threading import Thread
5
5
  from typing import (
6
6
  TYPE_CHECKING,
7
7
  Callable,
@@ -75,95 +75,90 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
75
75
  return
76
76
 
77
77
  self._event_queue: Queue[Tuple[int, WorkflowEvent]] = Queue()
78
- self._concurrency_queue: Queue[Thread] = Queue()
79
- fulfilled_iterations: List[bool] = []
80
- for index, item in enumerate(self.items):
81
- fulfilled_iterations.append(False)
78
+ fulfilled_iterations: List[bool] = [False] * len(self.items)
79
+
80
+ max_workers = self.max_concurrency if self.max_concurrency is not None else len(self.items)
81
+
82
+ with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
83
+ futures = []
82
84
  current_execution_context = get_execution_context()
83
- thread = Thread(
84
- target=self._context_run_subworkflow,
85
- kwargs={
86
- "item": item,
87
- "index": index,
88
- "current_execution_context": current_execution_context,
89
- },
90
- )
91
- if self.max_concurrency is None:
92
- thread.start()
93
- else:
94
- self._concurrency_queue.put(thread)
95
-
96
- if self.max_concurrency is not None:
97
- concurrency_count = 0
98
- while concurrency_count < self.max_concurrency:
99
- is_empty = self._start_thread()
100
- if is_empty:
101
- break
102
-
103
- concurrency_count += 1
104
-
105
- try:
106
- while map_node_event := self._event_queue.get():
107
- index = map_node_event[0]
108
- subworkflow_event = map_node_event[1]
109
- self._context._emit_subworkflow_event(subworkflow_event)
110
-
111
- if not is_workflow_event(subworkflow_event):
112
- continue
113
-
114
- if subworkflow_event.workflow_definition != self.subworkflow:
115
- continue
116
-
117
- if subworkflow_event.name == "workflow.execution.initiated":
118
- for output_name in mapped_items.keys():
119
- yield BaseOutput(name=output_name, delta=(None, index, "INITIATED"))
120
-
121
- elif subworkflow_event.name == "workflow.execution.fulfilled":
122
- for output_reference, output_value in subworkflow_event.outputs:
123
- if not isinstance(output_reference, OutputReference):
124
- logger.error(
125
- "Invalid key to map node's subworkflow event outputs",
126
- extra={"output_reference_type": type(output_reference)},
85
+ for index, item in enumerate(self.items):
86
+ future = executor.submit(
87
+ self._context_run_subworkflow,
88
+ item=item,
89
+ index=index,
90
+ current_execution_context=current_execution_context,
91
+ )
92
+ futures.append(future)
93
+
94
+ while not all(fulfilled_iterations):
95
+ try:
96
+ map_node_event = self._event_queue.get(block=False)
97
+ index = map_node_event[0]
98
+ subworkflow_event = map_node_event[1]
99
+ self._context._emit_subworkflow_event(subworkflow_event)
100
+
101
+ if not is_workflow_event(subworkflow_event):
102
+ continue
103
+
104
+ if subworkflow_event.workflow_definition != self.subworkflow:
105
+ continue
106
+
107
+ if subworkflow_event.name == "workflow.execution.initiated":
108
+ for output_name in mapped_items.keys():
109
+ yield BaseOutput(name=output_name, delta=(None, index, "INITIATED"))
110
+
111
+ elif subworkflow_event.name == "workflow.execution.fulfilled":
112
+ for output_reference, output_value in subworkflow_event.outputs:
113
+ if not isinstance(output_reference, OutputReference):
114
+ logger.error(
115
+ "Invalid key to map node's subworkflow event outputs",
116
+ extra={"output_reference_type": type(output_reference)},
117
+ )
118
+ continue
119
+
120
+ output_mapped_items = mapped_items[output_reference.name]
121
+ if index < 0 or index >= len(output_mapped_items):
122
+ logger.error(
123
+ "Invalid map node index",
124
+ extra={"index": index, "output_name": output_reference.name},
125
+ )
126
+ continue
127
+
128
+ output_mapped_items[index] = output_value
129
+ yield BaseOutput(
130
+ name=output_reference.name,
131
+ delta=(output_value, index, "FULFILLED"),
127
132
  )
128
- continue
129
133
 
130
- output_mapped_items = mapped_items[output_reference.name]
131
- if index < 0 or index >= len(output_mapped_items):
132
- logger.error(
133
- "Invalid map node index", extra={"index": index, "output_name": output_reference.name}
134
- )
135
- continue
134
+ fulfilled_iterations[index] = True
136
135
 
137
- output_mapped_items[index] = output_value
138
- yield BaseOutput(
139
- name=output_reference.name,
140
- delta=(output_value, index, "FULFILLED"),
136
+ elif subworkflow_event.name == "workflow.execution.paused":
137
+ raise NodeException(
138
+ code=WorkflowErrorCode.INVALID_OUTPUTS,
139
+ message=f"Subworkflow unexpectedly paused on iteration {index}",
141
140
  )
141
+ elif subworkflow_event.name == "workflow.execution.rejected":
142
+ raise NodeException(
143
+ f"Subworkflow failed on iteration {index} with error: {subworkflow_event.error.message}",
144
+ code=subworkflow_event.error.code,
145
+ )
146
+ except Empty:
147
+ all_futures_done = all(future.done() for future in futures)
142
148
 
143
- fulfilled_iterations[index] = True
144
- if all(fulfilled_iterations):
145
- break
146
-
147
- if self.max_concurrency is not None:
148
- self._start_thread()
149
- elif subworkflow_event.name == "workflow.execution.paused":
150
- raise NodeException(
151
- code=WorkflowErrorCode.INVALID_OUTPUTS,
152
- message=f"Subworkflow unexpectedly paused on iteration {index}",
153
- )
154
- elif subworkflow_event.name == "workflow.execution.rejected":
155
- raise NodeException(
156
- f"Subworkflow failed on iteration {index} with error: {subworkflow_event.error.message}",
157
- code=subworkflow_event.error.code,
158
- )
159
- except Empty:
160
- pass
149
+ if all_futures_done:
150
+ if not all(fulfilled_iterations):
151
+ if self._event_queue.empty():
152
+ logger.warning("All threads completed but not all iterations fulfilled")
153
+ break
154
+ else:
155
+ break
161
156
 
162
157
  for output_name, output_list in mapped_items.items():
163
158
  yield BaseOutput(name=output_name, value=output_list)
164
159
 
165
160
  def _context_run_subworkflow(
166
- self, *, item: MapNodeItemType, index: int, current_execution_context: ExecutionContext
161
+ self, item: MapNodeItemType, index: int, current_execution_context: ExecutionContext
167
162
  ) -> None:
168
163
  parent_context = current_execution_context.parent_context
169
164
  trace_id = current_execution_context.trace_id
@@ -186,14 +181,6 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
186
181
  for event in events:
187
182
  self._event_queue.put((index, event))
188
183
 
189
- def _start_thread(self) -> bool:
190
- if self._concurrency_queue.empty():
191
- return False
192
-
193
- thread = self._concurrency_queue.get()
194
- thread.start()
195
- return True
196
-
197
184
  @overload
198
185
  @classmethod
199
186
  def wrap(
@@ -1,3 +1,5 @@
1
+ import datetime
2
+ import threading
1
3
  import time
2
4
 
3
5
  from vellum.workflows.inputs.base import BaseInputs
@@ -172,3 +174,50 @@ def test_map_node__nested_map_node():
172
174
  ["apple carrot", "apple potato"],
173
175
  ["banana carrot", "banana potato"],
174
176
  ]
177
+
178
+
179
+ def test_map_node_parallel_execution_with_workflow():
180
+ # TODO: Find a better way to test this such that it represents what a user would see.
181
+ # https://linear.app/vellum/issue/APO-482/find-a-better-way-to-test-concurrency-with-map-nodes
182
+ thread_ids = {}
183
+
184
+ # GIVEN a series of nodes that simulate work
185
+ class BaseNode1(BaseNode):
186
+ item = MapNode.SubworkflowInputs.item
187
+
188
+ class Outputs(BaseOutputs):
189
+ output: str
190
+ thread_id: int
191
+
192
+ def run(self) -> Outputs:
193
+ current_thread_id = threading.get_ident()
194
+ thread_ids[self.item] = current_thread_id
195
+
196
+ # Simulate work
197
+ time.sleep(0.01)
198
+
199
+ end = time.time()
200
+ end_str = datetime.datetime.fromtimestamp(end).strftime("%Y-%m-%d %H:%M:%S.%f")
201
+
202
+ return self.Outputs(output=end_str, thread_id=current_thread_id)
203
+
204
+ # AND a workflow that connects these nodes
205
+ class TestWorkflow(BaseWorkflow[MapNode.SubworkflowInputs, BaseState]):
206
+ graph = BaseNode1
207
+
208
+ class Outputs(BaseWorkflow.Outputs):
209
+ final_output = BaseNode1.Outputs.output
210
+ thread_id = BaseNode1.Outputs.thread_id
211
+
212
+ # AND a map node that uses this workflow
213
+ class TestMapNode(MapNode):
214
+ items = [1, 2, 3]
215
+ subworkflow = TestWorkflow
216
+
217
+ # WHEN we run the map node
218
+ node = TestMapNode()
219
+ list(node.run())
220
+
221
+ # AND each item should have run on a different thread
222
+ thread_ids_list = list(thread_ids.values())
223
+ assert len(set(thread_ids_list)) == 3
@@ -20,7 +20,6 @@ from vellum.client import ApiError, RequestOptions
20
20
  from vellum.client.types.chat_message_request import ChatMessageRequest
21
21
  from vellum.client.types.prompt_settings import PromptSettings
22
22
  from vellum.client.types.rich_text_child_block import RichTextChildBlock
23
- from vellum.workflows.constants import OMIT
24
23
  from vellum.workflows.context import get_execution_context
25
24
  from vellum.workflows.errors import WorkflowErrorCode
26
25
  from vellum.workflows.errors.types import vellum_error_to_workflow_error
@@ -56,7 +55,7 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
56
55
  functions: Optional[List[Union[FunctionDefinition, Callable]]] = None
57
56
 
58
57
  parameters: PromptParameters = DEFAULT_PROMPT_PARAMETERS
59
- expand_meta: Optional[AdHocExpandMeta] = OMIT
58
+ expand_meta: Optional[AdHocExpandMeta] = None
60
59
 
61
60
  settings: Optional[PromptSettings] = None
62
61
 
@@ -270,7 +270,7 @@ def test_inline_prompt_node__json_output(vellum_adhoc_prompt_client):
270
270
  # AND we should have made the expected call to Vellum search
271
271
  vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.assert_called_once_with(
272
272
  blocks=[],
273
- expand_meta=Ellipsis,
273
+ expand_meta=None,
274
274
  functions=None,
275
275
  input_values=[],
276
276
  input_variables=[],
@@ -350,7 +350,7 @@ def test_inline_prompt_node__streaming_disabled(vellum_adhoc_prompt_client):
350
350
  # AND we should have made the expected call to Vellum search
351
351
  vellum_adhoc_prompt_client.adhoc_execute_prompt.assert_called_once_with(
352
352
  blocks=[],
353
- expand_meta=Ellipsis,
353
+ expand_meta=None,
354
354
  functions=None,
355
355
  input_values=[],
356
356
  input_variables=[],
@@ -444,7 +444,7 @@ def test_inline_prompt_node__json_output_with_streaming_disabled(vellum_adhoc_pr
444
444
  # AND we should have made the expected call to Vellum search
445
445
  vellum_adhoc_prompt_client.adhoc_execute_prompt.assert_called_once_with(
446
446
  blocks=[],
447
- expand_meta=Ellipsis,
447
+ expand_meta=None,
448
448
  functions=None,
449
449
  input_values=[],
450
450
  input_variables=[],
@@ -74,7 +74,7 @@ def test_inline_text_prompt_node__basic(vellum_adhoc_prompt_client):
74
74
  # AND we should have made the expected call to Vellum search
75
75
  vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.assert_called_once_with(
76
76
  blocks=[],
77
- expand_meta=Ellipsis,
77
+ expand_meta=None,
78
78
  functions=None,
79
79
  input_values=[],
80
80
  input_variables=[],
@@ -0,0 +1,3 @@
1
+ from .tool_calling_node import ToolCallingNode
2
+
3
+ __all__ = ["ToolCallingNode"]
@@ -0,0 +1,53 @@
1
+ from vellum.client.types.function_call import FunctionCall
2
+ from vellum.client.types.function_call_vellum_value import FunctionCallVellumValue
3
+ from vellum.workflows.nodes.experimental.tool_calling_node.utils import create_tool_router_node
4
+ from vellum.workflows.state.base import BaseState, StateMeta
5
+
6
+
7
+ def first_function() -> str:
8
+ return "first_function"
9
+
10
+
11
+ def second_function() -> str:
12
+ return "second_function"
13
+
14
+
15
+ def test_port_condition_match_function_name():
16
+ """
17
+ Test that the port condition correctly matches the function name.
18
+ """
19
+ # GIVEN a tool router node
20
+ router_node = create_tool_router_node(
21
+ ml_model="test-model",
22
+ blocks=[],
23
+ functions=[first_function, second_function],
24
+ prompt_inputs=None,
25
+ )
26
+
27
+ # AND a state with a function call to the first function
28
+ state = BaseState(
29
+ meta=StateMeta(
30
+ node_outputs={
31
+ router_node.Outputs.results: [
32
+ FunctionCallVellumValue(
33
+ value=FunctionCall(
34
+ arguments={}, id="call_zp7pBQjGAOBCr7lo0AbR1HXT", name="first_function", state="FULFILLED"
35
+ ),
36
+ )
37
+ ],
38
+ },
39
+ )
40
+ )
41
+
42
+ # WHEN the port condition is resolved
43
+ # THEN the first function port should be true
44
+ first_function_port = getattr(router_node.Ports, "first_function")
45
+ assert first_function_port.resolve_condition(state) is True
46
+
47
+ # AND the second function port should be false
48
+ second_function_port = getattr(router_node.Ports, "second_function")
49
+ assert second_function_port.resolve_condition(state) is False
50
+
51
+ # AND the default port should be false
52
+ default_port = getattr(router_node.Ports, "default")
53
+ assert default_port.resolve_condition(state) is False
@@ -57,12 +57,19 @@ def create_tool_router_node(
57
57
  Ports = type("Ports", (), {})
58
58
  for function in functions:
59
59
  function_name = function.__name__
60
- port_condition = LazyReference(
61
- lambda: (
62
- node.Outputs.results[0]["type"].equals("FUNCTION_CALL")
63
- & node.Outputs.results[0]["value"]["name"].equals(function_name)
60
+
61
+ # Avoid using lambda to capture function_name
62
+ # lambda will capture the function_name by reference,
63
+ # and if the function_name is changed, the port_condition will also change.
64
+ def create_port_condition(fn_name):
65
+ return LazyReference(
66
+ lambda: (
67
+ node.Outputs.results[0]["type"].equals("FUNCTION_CALL")
68
+ & node.Outputs.results[0]["value"]["name"].equals(fn_name)
69
+ )
64
70
  )
65
- )
71
+
72
+ port_condition = create_port_condition(function_name)
66
73
  port = Port.on_if(port_condition)
67
74
  setattr(Ports, function_name, port)
68
75
 
@@ -13,6 +13,7 @@ from vellum.workflows.inputs.base import BaseInputs
13
13
  from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
14
14
  from vellum.workflows.ports.port import Port
15
15
  from vellum.workflows.state.base import BaseState, NodeExecutionCache
16
+ from vellum.workflows.utils.functions import compile_function_definition
16
17
 
17
18
 
18
19
  class DefaultStateEncoder(JSONEncoder):
@@ -57,6 +58,9 @@ class DefaultStateEncoder(JSONEncoder):
57
58
  if isinstance(obj, type):
58
59
  return str(obj)
59
60
 
61
+ if callable(obj):
62
+ return compile_function_definition(obj)
63
+
60
64
  if obj.__class__ in self.encoders:
61
65
  return self.encoders[obj.__class__](obj)
62
66
 
@@ -276,6 +276,14 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
276
276
  """
277
277
  return cls._get_edges_from_subgraphs(cls.get_unused_subgraphs())
278
278
 
279
+ @classmethod
280
+ def get_all_nodes(cls) -> Iterator[Type[BaseNode]]:
281
+ """
282
+ Returns an iterator over all nodes in the Workflow, used or unused.
283
+ """
284
+ yield from cls.get_nodes()
285
+ yield from cls.get_unused_nodes()
286
+
279
287
  @classmethod
280
288
  def get_entrypoints(cls) -> Iterable[Type[BaseNode]]:
281
289
  return iter({e for g in cls.get_subgraphs() for e in g.entrypoints})
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vellum-ai
3
- Version: 0.14.49
3
+ Version: 0.14.51
4
4
  Summary:
5
5
  License: MIT
6
6
  Requires-Python: >=3.9,<4.0