vellum-ai 0.14.49__py3-none-any.whl → 0.14.51__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- vellum/__init__.py +6 -2
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/types/__init__.py +6 -2
- vellum/client/types/deployment_read.py +1 -1
- vellum/client/types/slim_workflow_execution_read.py +2 -2
- vellum/client/types/workflow_event_execution_read.py +2 -2
- vellum/client/types/{workflow_execution_usage_calculation_fulfilled_body.py → workflow_execution_usage_calculation_error.py} +5 -6
- vellum/client/types/workflow_execution_usage_calculation_error_code_enum.py +7 -0
- vellum/client/types/workflow_execution_usage_result.py +24 -0
- vellum/types/{workflow_execution_usage_calculation_fulfilled_body.py → workflow_execution_usage_calculation_error.py} +1 -1
- vellum/types/workflow_execution_usage_calculation_error_code_enum.py +3 -0
- vellum/types/workflow_execution_usage_result.py +3 -0
- vellum/workflows/nodes/core/map_node/node.py +74 -87
- vellum/workflows/nodes/core/map_node/tests/test_node.py +49 -0
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +1 -2
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/tests/test_inline_prompt_node.py +3 -3
- vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py +1 -1
- vellum/workflows/nodes/experimental/__init__.py +3 -0
- vellum/workflows/nodes/experimental/tool_calling_node/tests/test_tool_calling_node.py +53 -0
- vellum/workflows/nodes/experimental/tool_calling_node/utils.py +12 -5
- vellum/workflows/state/encoder.py +4 -0
- vellum/workflows/workflows/base.py +8 -0
- {vellum_ai-0.14.49.dist-info → vellum_ai-0.14.51.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.49.dist-info → vellum_ai-0.14.51.dist-info}/RECORD +35 -29
- vellum_ee/workflows/display/nodes/base_node_display.py +31 -2
- vellum_ee/workflows/display/nodes/get_node_display_class.py +1 -24
- vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +29 -12
- vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_node.py +33 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_serialization.py +177 -0
- vellum_ee/workflows/display/utils/expressions.py +1 -1
- vellum_ee/workflows/display/workflows/base_workflow_display.py +3 -24
- vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +3 -3
- {vellum_ai-0.14.49.dist-info → vellum_ai-0.14.51.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.49.dist-info → vellum_ai-0.14.51.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.49.dist-info → vellum_ai-0.14.51.dist-info}/entry_points.txt +0 -0
vellum/__init__.py
CHANGED
@@ -544,7 +544,9 @@ from .types import (
     WorkflowExecutionSpanAttributes,
     WorkflowExecutionStreamingBody,
     WorkflowExecutionStreamingEvent,
-
+    WorkflowExecutionUsageCalculationError,
+    WorkflowExecutionUsageCalculationErrorCodeEnum,
+    WorkflowExecutionUsageResult,
     WorkflowExecutionViewOnlineEvalMetricResult,
     WorkflowExecutionWorkflowResultEvent,
     WorkflowExpandMetaRequest,
@@ -1177,7 +1179,9 @@ __all__ = [
     "WorkflowExecutionSpanAttributes",
     "WorkflowExecutionStreamingBody",
     "WorkflowExecutionStreamingEvent",
-    "
+    "WorkflowExecutionUsageCalculationError",
+    "WorkflowExecutionUsageCalculationErrorCodeEnum",
+    "WorkflowExecutionUsageResult",
     "WorkflowExecutionViewOnlineEvalMetricResult",
     "WorkflowExecutionWorkflowResultEvent",
     "WorkflowExpandMetaRequest",
vellum/client/core/client_wrapper.py
CHANGED
@@ -18,7 +18,7 @@ class BaseClientWrapper:
         headers: typing.Dict[str, str] = {
             "X-Fern-Language": "Python",
             "X-Fern-SDK-Name": "vellum-ai",
-            "X-Fern-SDK-Version": "0.14.49",
+            "X-Fern-SDK-Version": "0.14.51",
         }
         headers["X-API-KEY"] = self.api_key
         return headers
vellum/client/types/__init__.py
CHANGED
@@ -568,7 +568,9 @@ from .workflow_execution_span import WorkflowExecutionSpan
 from .workflow_execution_span_attributes import WorkflowExecutionSpanAttributes
 from .workflow_execution_streaming_body import WorkflowExecutionStreamingBody
 from .workflow_execution_streaming_event import WorkflowExecutionStreamingEvent
-from .
+from .workflow_execution_usage_calculation_error import WorkflowExecutionUsageCalculationError
+from .workflow_execution_usage_calculation_error_code_enum import WorkflowExecutionUsageCalculationErrorCodeEnum
+from .workflow_execution_usage_result import WorkflowExecutionUsageResult
 from .workflow_execution_view_online_eval_metric_result import WorkflowExecutionViewOnlineEvalMetricResult
 from .workflow_execution_workflow_result_event import WorkflowExecutionWorkflowResultEvent
 from .workflow_expand_meta_request import WorkflowExpandMetaRequest
@@ -1154,7 +1156,9 @@ __all__ = [
     "WorkflowExecutionSpanAttributes",
     "WorkflowExecutionStreamingBody",
     "WorkflowExecutionStreamingEvent",
-    "
+    "WorkflowExecutionUsageCalculationError",
+    "WorkflowExecutionUsageCalculationErrorCodeEnum",
+    "WorkflowExecutionUsageResult",
     "WorkflowExecutionViewOnlineEvalMetricResult",
     "WorkflowExecutionWorkflowResultEvent",
     "WorkflowExpandMetaRequest",
vellum/client/types/deployment_read.py
CHANGED
@@ -50,7 +50,7 @@ class DeploymentRead(UniversalBaseModel):

     active_model_version_ids: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
     """
-    Deprecated. This now always returns
+    Deprecated. This now always returns an empty array.
     """

     last_deployed_history_item_id: str = pydantic.Field()
vellum/client/types/slim_workflow_execution_read.py
CHANGED
@@ -15,7 +15,7 @@ from .execution_vellum_value import ExecutionVellumValue
 from .workflow_error import WorkflowError
 from .workflow_execution_actual import WorkflowExecutionActual
 from .workflow_execution_view_online_eval_metric_result import WorkflowExecutionViewOnlineEvalMetricResult
-from .
+from .workflow_execution_usage_result import WorkflowExecutionUsageResult
 from ..core.pydantic_utilities import IS_PYDANTIC_V2
 import pydantic

@@ -30,7 +30,7 @@ class SlimWorkflowExecutionRead(UniversalBaseModel):
     error: typing.Optional[WorkflowError] = None
     latest_actual: typing.Optional[WorkflowExecutionActual] = None
     metric_results: typing.List[WorkflowExecutionViewOnlineEvalMetricResult]
-    usage_results: typing.List[
+    usage_results: typing.Optional[typing.List[WorkflowExecutionUsageResult]] = None

     if IS_PYDANTIC_V2:
         model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
vellum/client/types/workflow_event_execution_read.py
CHANGED
@@ -15,7 +15,7 @@ from .execution_vellum_value import ExecutionVellumValue
 from .workflow_error import WorkflowError
 from .workflow_execution_actual import WorkflowExecutionActual
 from .workflow_execution_view_online_eval_metric_result import WorkflowExecutionViewOnlineEvalMetricResult
-from .
+from .workflow_execution_usage_result import WorkflowExecutionUsageResult
 from .vellum_span import VellumSpan
 from ..core.pydantic_utilities import IS_PYDANTIC_V2
 import pydantic
@@ -31,7 +31,7 @@ class WorkflowEventExecutionRead(UniversalBaseModel):
     error: typing.Optional[WorkflowError] = None
     latest_actual: typing.Optional[WorkflowExecutionActual] = None
     metric_results: typing.List[WorkflowExecutionViewOnlineEvalMetricResult]
-    usage_results: typing.List[
+    usage_results: typing.Optional[typing.List[WorkflowExecutionUsageResult]] = None
     spans: typing.List[VellumSpan]

     if IS_PYDANTIC_V2:
vellum/client/types/{workflow_execution_usage_calculation_fulfilled_body.py → workflow_execution_usage_calculation_error.py}
RENAMED
@@ -1,16 +1,15 @@
 # This file was auto-generated by Fern from our API Definition.

 from ..core.pydantic_utilities import UniversalBaseModel
-import
-from .ml_model_usage_wrapper import MlModelUsageWrapper
-from .price import Price
+from .workflow_execution_usage_calculation_error_code_enum import WorkflowExecutionUsageCalculationErrorCodeEnum
 from ..core.pydantic_utilities import IS_PYDANTIC_V2
+import typing
 import pydantic


-class
-
-
+class WorkflowExecutionUsageCalculationError(UniversalBaseModel):
+    code: WorkflowExecutionUsageCalculationErrorCodeEnum
+    message: str

     if IS_PYDANTIC_V2:
         model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
vellum/client/types/workflow_execution_usage_result.py
ADDED
@@ -0,0 +1,24 @@
+# This file was auto-generated by Fern from our API Definition.
+
+from ..core.pydantic_utilities import UniversalBaseModel
+import typing
+from .ml_model_usage_wrapper import MlModelUsageWrapper
+from .price import Price
+from .workflow_execution_usage_calculation_error import WorkflowExecutionUsageCalculationError
+from ..core.pydantic_utilities import IS_PYDANTIC_V2
+import pydantic
+
+
+class WorkflowExecutionUsageResult(UniversalBaseModel):
+    usage: typing.Optional[typing.List[MlModelUsageWrapper]] = None
+    cost: typing.Optional[typing.List[Price]] = None
+    error: typing.Optional[WorkflowExecutionUsageCalculationError] = None
+
+    if IS_PYDANTIC_V2:
+        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
+    else:
+
+        class Config:
+            frozen = True
+            smart_union = True
+            extra = pydantic.Extra.allow
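The new usage types surface per-execution usage on `SlimWorkflowExecutionRead` and `WorkflowEventExecutionRead` via the now-optional `usage_results` field. A minimal sketch of consuming it defensively; the `summarize_usage` helper is illustrative and not part of the SDK:

```python
from typing import List, Optional

from vellum.client.types.workflow_execution_usage_result import WorkflowExecutionUsageResult


def summarize_usage(usage_results: Optional[List[WorkflowExecutionUsageResult]]) -> None:
    # usage_results is Optional as of this release, so guard against None.
    for result in usage_results or []:
        if result.error is not None:
            # The error model carries a code enum plus a human-readable message.
            print(f"usage calculation failed: {result.error.code}: {result.error.message}")
            continue
        print(f"usage entries: {len(result.usage or [])}, cost entries: {len(result.cost or [])}")
```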
vellum/workflows/nodes/core/map_node/node.py
CHANGED
@@ -1,7 +1,7 @@
 from collections import defaultdict
+import concurrent.futures
 import logging
 from queue import Empty, Queue
-from threading import Thread
 from typing import (
     TYPE_CHECKING,
     Callable,
@@ -75,95 +75,90 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
             return

         self._event_queue: Queue[Tuple[int, WorkflowEvent]] = Queue()
-
-
-
-
+        fulfilled_iterations: List[bool] = [False] * len(self.items)
+
+        max_workers = self.max_concurrency if self.max_concurrency is not None else len(self.items)
+
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = []
             current_execution_context = get_execution_context()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            for index, item in enumerate(self.items):
+                future = executor.submit(
+                    self._context_run_subworkflow,
+                    item=item,
+                    index=index,
+                    current_execution_context=current_execution_context,
+                )
+                futures.append(future)
+
+            while not all(fulfilled_iterations):
+                try:
+                    map_node_event = self._event_queue.get(block=False)
+                    index = map_node_event[0]
+                    subworkflow_event = map_node_event[1]
+                    self._context._emit_subworkflow_event(subworkflow_event)
+
+                    if not is_workflow_event(subworkflow_event):
+                        continue
+
+                    if subworkflow_event.workflow_definition != self.subworkflow:
+                        continue
+
+                    if subworkflow_event.name == "workflow.execution.initiated":
+                        for output_name in mapped_items.keys():
+                            yield BaseOutput(name=output_name, delta=(None, index, "INITIATED"))
+
+                    elif subworkflow_event.name == "workflow.execution.fulfilled":
+                        for output_reference, output_value in subworkflow_event.outputs:
+                            if not isinstance(output_reference, OutputReference):
+                                logger.error(
+                                    "Invalid key to map node's subworkflow event outputs",
+                                    extra={"output_reference_type": type(output_reference)},
+                                )
+                                continue
+
+                            output_mapped_items = mapped_items[output_reference.name]
+                            if index < 0 or index >= len(output_mapped_items):
+                                logger.error(
+                                    "Invalid map node index",
+                                    extra={"index": index, "output_name": output_reference.name},
+                                )
+                                continue
+
+                            output_mapped_items[index] = output_value
+                            yield BaseOutput(
+                                name=output_reference.name,
+                                delta=(output_value, index, "FULFILLED"),
                             )
-                            continue

-
-                            if index < 0 or index >= len(output_mapped_items):
-                                logger.error(
-                                    "Invalid map node index", extra={"index": index, "output_name": output_reference.name}
-                                )
-                                continue
+                        fulfilled_iterations[index] = True

-
-
-
-
+                    elif subworkflow_event.name == "workflow.execution.paused":
+                        raise NodeException(
+                            code=WorkflowErrorCode.INVALID_OUTPUTS,
+                            message=f"Subworkflow unexpectedly paused on iteration {index}",
                        )
+                    elif subworkflow_event.name == "workflow.execution.rejected":
+                        raise NodeException(
+                            f"Subworkflow failed on iteration {index} with error: {subworkflow_event.error.message}",
+                            code=subworkflow_event.error.code,
+                        )
+                except Empty:
+                    all_futures_done = all(future.done() for future in futures)

-
-
-
-
-
-
-
-                        raise NodeException(
-                            code=WorkflowErrorCode.INVALID_OUTPUTS,
-                            message=f"Subworkflow unexpectedly paused on iteration {index}",
-                        )
-                    elif subworkflow_event.name == "workflow.execution.rejected":
-                        raise NodeException(
-                            f"Subworkflow failed on iteration {index} with error: {subworkflow_event.error.message}",
-                            code=subworkflow_event.error.code,
-                        )
-                except Empty:
-                    pass
+                    if all_futures_done:
+                        if not all(fulfilled_iterations):
+                            if self._event_queue.empty():
+                                logger.warning("All threads completed but not all iterations fulfilled")
+                                break
+                        else:
+                            break

         for output_name, output_list in mapped_items.items():
             yield BaseOutput(name=output_name, value=output_list)

     def _context_run_subworkflow(
-        self,
+        self, item: MapNodeItemType, index: int, current_execution_context: ExecutionContext
     ) -> None:
         parent_context = current_execution_context.parent_context
         trace_id = current_execution_context.trace_id
@@ -186,14 +181,6 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
         for event in events:
             self._event_queue.put((index, event))

-    def _start_thread(self) -> bool:
-        if self._concurrency_queue.empty():
-            return False
-
-        thread = self._concurrency_queue.get()
-        thread.start()
-        return True
-
     @overload
     @classmethod
     def wrap(
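The MapNode refactor above swaps hand-managed `Thread` objects for a `concurrent.futures.ThreadPoolExecutor` bounded by `max_concurrency` (or the item count), with every iteration reporting back through one shared queue. A standalone sketch of that fan-out/drain pattern using only the standard library; the `worker` function and its event strings are illustrative, not the MapNode internals:

```python
import concurrent.futures
from queue import Empty, Queue
from typing import List, Optional, Tuple


def fan_out(items: List[str], max_concurrency: Optional[int] = None) -> List[bool]:
    event_queue: "Queue[Tuple[int, str]]" = Queue()
    fulfilled = [False] * len(items)

    def worker(index: int, item: str) -> None:
        # Each iteration pushes (index, event) pairs onto the shared queue.
        event_queue.put((index, f"fulfilled:{item}"))

    max_workers = max_concurrency if max_concurrency is not None else len(items)
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(worker, i, item) for i, item in enumerate(items)]

        while not all(fulfilled):
            try:
                index, event = event_queue.get(block=False)
                if event.startswith("fulfilled"):
                    fulfilled[index] = True
            except Empty:
                # Stop once every future has finished and no events remain.
                if all(f.done() for f in futures) and event_queue.empty():
                    break

    return fulfilled


print(fan_out(["a", "b", "c"], max_concurrency=2))  # [True, True, True]
```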
vellum/workflows/nodes/core/map_node/tests/test_node.py
CHANGED
@@ -1,3 +1,5 @@
+import datetime
+import threading
 import time

 from vellum.workflows.inputs.base import BaseInputs
@@ -172,3 +174,50 @@ def test_map_node__nested_map_node():
         ["apple carrot", "apple potato"],
         ["banana carrot", "banana potato"],
     ]
+
+
+def test_map_node_parallel_execution_with_workflow():
+    # TODO: Find a better way to test this such that it represents what a user would see.
+    # https://linear.app/vellum/issue/APO-482/find-a-better-way-to-test-concurrency-with-map-nodes
+    thread_ids = {}
+
+    # GIVEN a series of nodes that simulate work
+    class BaseNode1(BaseNode):
+        item = MapNode.SubworkflowInputs.item
+
+        class Outputs(BaseOutputs):
+            output: str
+            thread_id: int
+
+        def run(self) -> Outputs:
+            current_thread_id = threading.get_ident()
+            thread_ids[self.item] = current_thread_id
+
+            # Simulate work
+            time.sleep(0.01)
+
+            end = time.time()
+            end_str = datetime.datetime.fromtimestamp(end).strftime("%Y-%m-%d %H:%M:%S.%f")
+
+            return self.Outputs(output=end_str, thread_id=current_thread_id)
+
+    # AND a workflow that connects these nodes
+    class TestWorkflow(BaseWorkflow[MapNode.SubworkflowInputs, BaseState]):
+        graph = BaseNode1
+
+        class Outputs(BaseWorkflow.Outputs):
+            final_output = BaseNode1.Outputs.output
+            thread_id = BaseNode1.Outputs.thread_id
+
+    # AND a map node that uses this workflow
+    class TestMapNode(MapNode):
+        items = [1, 2, 3]
+        subworkflow = TestWorkflow
+
+    # WHEN we run the map node
+    node = TestMapNode()
+    list(node.run())
+
+    # AND each item should have run on a different thread
+    thread_ids_list = list(thread_ids.values())
+    assert len(set(thread_ids_list)) == 3
vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py
CHANGED
@@ -20,7 +20,6 @@ from vellum.client import ApiError, RequestOptions
 from vellum.client.types.chat_message_request import ChatMessageRequest
 from vellum.client.types.prompt_settings import PromptSettings
 from vellum.client.types.rich_text_child_block import RichTextChildBlock
-from vellum.workflows.constants import OMIT
 from vellum.workflows.context import get_execution_context
 from vellum.workflows.errors import WorkflowErrorCode
 from vellum.workflows.errors.types import vellum_error_to_workflow_error
@@ -56,7 +55,7 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
     functions: Optional[List[Union[FunctionDefinition, Callable]]] = None

     parameters: PromptParameters = DEFAULT_PROMPT_PARAMETERS
-    expand_meta: Optional[AdHocExpandMeta] =
+    expand_meta: Optional[AdHocExpandMeta] = None

     settings: Optional[PromptSettings] = None

vellum/workflows/nodes/displayable/bases/inline_prompt_node/tests/test_inline_prompt_node.py
CHANGED
@@ -270,7 +270,7 @@ def test_inline_prompt_node__json_output(vellum_adhoc_prompt_client):
     # AND we should have made the expected call to Vellum search
     vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.assert_called_once_with(
         blocks=[],
-        expand_meta=
+        expand_meta=None,
         functions=None,
         input_values=[],
         input_variables=[],
@@ -350,7 +350,7 @@ def test_inline_prompt_node__streaming_disabled(vellum_adhoc_prompt_client):
     # AND we should have made the expected call to Vellum search
     vellum_adhoc_prompt_client.adhoc_execute_prompt.assert_called_once_with(
         blocks=[],
-        expand_meta=
+        expand_meta=None,
         functions=None,
         input_values=[],
         input_variables=[],
@@ -444,7 +444,7 @@ def test_inline_prompt_node__json_output_with_streaming_disabled(vellum_adhoc_prompt_client):
     # AND we should have made the expected call to Vellum search
     vellum_adhoc_prompt_client.adhoc_execute_prompt.assert_called_once_with(
         blocks=[],
-        expand_meta=
+        expand_meta=None,
         functions=None,
         input_values=[],
         input_variables=[],
vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py
CHANGED
@@ -74,7 +74,7 @@ def test_inline_text_prompt_node__basic(vellum_adhoc_prompt_client):
     # AND we should have made the expected call to Vellum search
     vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.assert_called_once_with(
         blocks=[],
-        expand_meta=
+        expand_meta=None,
         functions=None,
         input_values=[],
         input_variables=[],
vellum/workflows/nodes/experimental/tool_calling_node/tests/test_tool_calling_node.py
ADDED
@@ -0,0 +1,53 @@
+from vellum.client.types.function_call import FunctionCall
+from vellum.client.types.function_call_vellum_value import FunctionCallVellumValue
+from vellum.workflows.nodes.experimental.tool_calling_node.utils import create_tool_router_node
+from vellum.workflows.state.base import BaseState, StateMeta
+
+
+def first_function() -> str:
+    return "first_function"
+
+
+def second_function() -> str:
+    return "second_function"
+
+
+def test_port_condition_match_function_name():
+    """
+    Test that the port condition correctly matches the function name.
+    """
+    # GIVEN a tool router node
+    router_node = create_tool_router_node(
+        ml_model="test-model",
+        blocks=[],
+        functions=[first_function, second_function],
+        prompt_inputs=None,
+    )
+
+    # AND a state with a function call to the first function
+    state = BaseState(
+        meta=StateMeta(
+            node_outputs={
+                router_node.Outputs.results: [
+                    FunctionCallVellumValue(
+                        value=FunctionCall(
+                            arguments={}, id="call_zp7pBQjGAOBCr7lo0AbR1HXT", name="first_function", state="FULFILLED"
+                        ),
+                    )
+                ],
+            },
+        )
+    )
+
+    # WHEN the port condition is resolved
+    # THEN the first function port should be true
+    first_function_port = getattr(router_node.Ports, "first_function")
+    assert first_function_port.resolve_condition(state) is True
+
+    # AND the second function port should be false
+    second_function_port = getattr(router_node.Ports, "second_function")
+    assert second_function_port.resolve_condition(state) is False
+
+    # AND the default port should be false
+    default_port = getattr(router_node.Ports, "default")
+    assert default_port.resolve_condition(state) is False
vellum/workflows/nodes/experimental/tool_calling_node/utils.py
CHANGED
@@ -57,12 +57,19 @@ def create_tool_router_node(
     Ports = type("Ports", (), {})
     for function in functions:
         function_name = function.__name__
-
-
-
-
+
+        # Avoid using lambda to capture function_name
+        # lambda will capture the function_name by reference,
+        # and if the function_name is changed, the port_condition will also change.
+        def create_port_condition(fn_name):
+            return LazyReference(
+                lambda: (
+                    node.Outputs.results[0]["type"].equals("FUNCTION_CALL")
+                    & node.Outputs.results[0]["value"]["name"].equals(fn_name)
+                )
             )
-
+
+        port_condition = create_port_condition(function_name)
         port = Port.on_if(port_condition)
         setattr(Ports, function_name, port)

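The `create_port_condition` helper exists because of Python's late binding of closure variables: a bare `lambda` inside the loop would read whatever `function_name` holds when the condition is finally evaluated, so every port would match the last function. A small self-contained illustration of the pitfall and the factory-function fix, independent of the Vellum APIs:

```python
# Late binding: every lambda reads `name` when it is called, after the loop has
# finished, so both closures return the final value.
conditions = [lambda: name for name in ["first_function", "second_function"]]
assert [c() for c in conditions] == ["second_function", "second_function"]


# Fix: route the current value through a factory function (as the diff does),
# which binds `fn_name` in its own scope at creation time.
def make_condition(fn_name):
    return lambda: fn_name


conditions = [make_condition(name) for name in ["first_function", "second_function"]]
assert [c() for c in conditions] == ["first_function", "second_function"]
```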
vellum/workflows/state/encoder.py
CHANGED
@@ -13,6 +13,7 @@ from vellum.workflows.inputs.base import BaseInputs
 from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
 from vellum.workflows.ports.port import Port
 from vellum.workflows.state.base import BaseState, NodeExecutionCache
+from vellum.workflows.utils.functions import compile_function_definition


 class DefaultStateEncoder(JSONEncoder):
@@ -57,6 +58,9 @@ class DefaultStateEncoder(JSONEncoder):
         if isinstance(obj, type):
             return str(obj)

+        if callable(obj):
+            return compile_function_definition(obj)
+
         if obj.__class__ in self.encoders:
             return self.encoders[obj.__class__](obj)

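With this encoder change, bare callables encountered during state serialization are converted to function definitions instead of raising `TypeError`. A rough sketch of the effect, assuming the new branch lives in the encoder's `default()` override (the standard hook for `JSONEncoder` subclasses):

```python
from vellum.workflows.state.encoder import DefaultStateEncoder


def get_weather(city: str) -> str:
    """Return the weather for a city."""
    return f"sunny in {city}"


encoder = DefaultStateEncoder()
# Callables are now routed through compile_function_definition, so this returns
# the compiled function definition for get_weather rather than raising TypeError
# as the base JSONEncoder would.
definition = encoder.default(get_weather)
print(definition)
```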
vellum/workflows/workflows/base.py
CHANGED
@@ -276,6 +276,14 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
         """
         return cls._get_edges_from_subgraphs(cls.get_unused_subgraphs())

+    @classmethod
+    def get_all_nodes(cls) -> Iterator[Type[BaseNode]]:
+        """
+        Returns an iterator over all nodes in the Workflow, used or unused.
+        """
+        yield from cls.get_nodes()
+        yield from cls.get_unused_nodes()
+
     @classmethod
     def get_entrypoints(cls) -> Iterable[Type[BaseNode]]:
         return iter({e for g in cls.get_subgraphs() for e in g.entrypoints})
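A sketch of how the new `get_all_nodes` classmethod might be used for workflow introspection. `StartNode` and `ExampleWorkflow` are placeholders modeled on the MapNode test above, and the import paths are assumed from the SDK layout rather than confirmed by this diff:

```python
from vellum.workflows import BaseWorkflow
from vellum.workflows.nodes.bases import BaseNode


class StartNode(BaseNode):
    pass


class ExampleWorkflow(BaseWorkflow):
    graph = StartNode


# get_all_nodes chains get_nodes() (nodes wired into `graph`) with
# get_unused_nodes(), so node classes declared but not connected are included too.
for node_cls in ExampleWorkflow.get_all_nodes():
    print(node_cls.__name__)
```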
|