vellum-ai 0.14.53__py3-none-any.whl → 0.14.55__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/__init__.py +4 -0
- vellum/client/__init__.py +4 -4
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/reference.md +116 -114
- vellum/client/resources/deployments/client.py +121 -2
- vellum/client/resources/release_reviews/client.py +1 -118
- vellum/client/resources/workflow_deployments/client.py +4 -2
- vellum/client/types/__init__.py +4 -0
- vellum/client/types/components_schemas_prompt_version_build_config_sandbox.py +5 -0
- vellum/client/types/execute_api_request_body.py +3 -1
- vellum/client/types/prompt_deployment_release_prompt_version.py +6 -1
- vellum/client/types/prompt_version_build_config_sandbox.py +22 -0
- vellum/types/components_schemas_prompt_version_build_config_sandbox.py +3 -0
- vellum/types/prompt_version_build_config_sandbox.py +3 -0
- vellum/workflows/nodes/bases/base.py +32 -1
- vellum/workflows/nodes/displayable/code_execution_node/tests/{test_code_execution_node.py → test_node.py} +139 -16
- vellum/workflows/nodes/experimental/tool_calling_node/node.py +6 -6
- vellum/workflows/nodes/experimental/tool_calling_node/utils.py +4 -1
- vellum/workflows/nodes/tests/test_utils.py +8 -21
- vellum/workflows/nodes/utils.py +4 -1
- vellum/workflows/runner/runner.py +1 -1
- vellum/workflows/state/base.py +0 -18
- vellum/workflows/types/code_execution_node_wrappers.py +28 -0
- vellum/workflows/utils/functions.py +1 -0
- vellum/workflows/utils/tests/test_functions.py +14 -0
- {vellum_ai-0.14.53.dist-info → vellum_ai-0.14.55.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.53.dist-info → vellum_ai-0.14.55.dist-info}/RECORD +36 -32
- vellum_cli/logger.py +11 -0
- vellum_cli/push.py +36 -32
- vellum_cli/tests/test_push.py +31 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_serialization.py +2 -7
- vellum_ee/workflows/display/workflows/base_workflow_display.py +20 -4
- vellum_ee/workflows/tests/test_display_meta.py +48 -0
- {vellum_ai-0.14.53.dist-info → vellum_ai-0.14.55.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.53.dist-info → vellum_ai-0.14.55.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.53.dist-info → vellum_ai-0.14.55.dist-info}/entry_points.txt +0 -0
vellum/client/resources/release_reviews/client.py
CHANGED
@@ -3,12 +3,11 @@
 from ...core.client_wrapper import SyncClientWrapper
 import typing
 from ...core.request_options import RequestOptions
-from ...types.
+from ...types.workflow_deployment_release import WorkflowDeploymentRelease
 from ...core.jsonable_encoder import jsonable_encoder
 from ...core.pydantic_utilities import parse_obj_as
 from json.decoder import JSONDecodeError
 from ...core.api_error import ApiError
-from ...types.workflow_deployment_release import WorkflowDeploymentRelease
 from ...core.client_wrapper import AsyncClientWrapper
 
 
@@ -16,60 +15,6 @@ class ReleaseReviewsClient:
     def __init__(self, *, client_wrapper: SyncClientWrapper):
         self._client_wrapper = client_wrapper
 
-    def retrieve_prompt_deployment_release(
-        self, id: str, release_id_or_release_tag: str, *, request_options: typing.Optional[RequestOptions] = None
-    ) -> PromptDeploymentRelease:
-        """
-        Retrieve a specific Prompt Deployment Release by either its UUID or the name of a Release Tag that points to it.
-
-        Parameters
-        ----------
-        id : str
-            A UUID string identifying this deployment.
-
-        release_id_or_release_tag : str
-            Either the UUID of Prompt Deployment Release you'd like to retrieve, or the name of a Release Tag that's pointing to the Prompt Deployment Release you'd like to retrieve.
-
-        request_options : typing.Optional[RequestOptions]
-            Request-specific configuration.
-
-        Returns
-        -------
-        PromptDeploymentRelease
-
-
-        Examples
-        --------
-        from vellum import Vellum
-
-        client = Vellum(
-            api_key="YOUR_API_KEY",
-        )
-        client.release_reviews.retrieve_prompt_deployment_release(
-            id="id",
-            release_id_or_release_tag="release_id_or_release_tag",
-        )
-        """
-        _response = self._client_wrapper.httpx_client.request(
-            f"v1/deployments/{jsonable_encoder(id)}/releases/{jsonable_encoder(release_id_or_release_tag)}",
-            base_url=self._client_wrapper.get_environment().default,
-            method="GET",
-            request_options=request_options,
-        )
-        try:
-            if 200 <= _response.status_code < 300:
-                return typing.cast(
-                    PromptDeploymentRelease,
-                    parse_obj_as(
-                        type_=PromptDeploymentRelease,  # type: ignore
-                        object_=_response.json(),
-                    ),
-                )
-            _response_json = _response.json()
-        except JSONDecodeError:
-            raise ApiError(status_code=_response.status_code, body=_response.text)
-        raise ApiError(status_code=_response.status_code, body=_response_json)
-
     def retrieve_workflow_deployment_release(
         self, id: str, release_id_or_release_tag: str, *, request_options: typing.Optional[RequestOptions] = None
     ) -> WorkflowDeploymentRelease:
@@ -129,68 +74,6 @@ class AsyncReleaseReviewsClient:
     def __init__(self, *, client_wrapper: AsyncClientWrapper):
         self._client_wrapper = client_wrapper
 
-    async def retrieve_prompt_deployment_release(
-        self, id: str, release_id_or_release_tag: str, *, request_options: typing.Optional[RequestOptions] = None
-    ) -> PromptDeploymentRelease:
-        """
-        Retrieve a specific Prompt Deployment Release by either its UUID or the name of a Release Tag that points to it.
-
-        Parameters
-        ----------
-        id : str
-            A UUID string identifying this deployment.
-
-        release_id_or_release_tag : str
-            Either the UUID of Prompt Deployment Release you'd like to retrieve, or the name of a Release Tag that's pointing to the Prompt Deployment Release you'd like to retrieve.
-
-        request_options : typing.Optional[RequestOptions]
-            Request-specific configuration.
-
-        Returns
-        -------
-        PromptDeploymentRelease
-
-
-        Examples
-        --------
-        import asyncio
-
-        from vellum import AsyncVellum
-
-        client = AsyncVellum(
-            api_key="YOUR_API_KEY",
-        )
-
-
-        async def main() -> None:
-            await client.release_reviews.retrieve_prompt_deployment_release(
-                id="id",
-                release_id_or_release_tag="release_id_or_release_tag",
-            )
-
-
-        asyncio.run(main())
-        """
-        _response = await self._client_wrapper.httpx_client.request(
-            f"v1/deployments/{jsonable_encoder(id)}/releases/{jsonable_encoder(release_id_or_release_tag)}",
-            base_url=self._client_wrapper.get_environment().default,
-            method="GET",
-            request_options=request_options,
-        )
-        try:
-            if 200 <= _response.status_code < 300:
-                return typing.cast(
-                    PromptDeploymentRelease,
-                    parse_obj_as(
-                        type_=PromptDeploymentRelease,  # type: ignore
-                        object_=_response.json(),
-                    ),
-                )
-            _response_json = _response.json()
-        except JSONDecodeError:
-            raise ApiError(status_code=_response.status_code, body=_response.text)
-        raise ApiError(status_code=_response.status_code, body=_response_json)
-
     async def retrieve_workflow_deployment_release(
         self, id: str, release_id_or_release_tag: str, *, request_options: typing.Optional[RequestOptions] = None
     ) -> WorkflowDeploymentRelease:
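The sync and async `retrieve_prompt_deployment_release` methods are removed from the release reviews resource in this version. Judging by the +121-line change to `vellum/client/resources/deployments/client.py` in the file list, the retrieval presumably now lives on the deployments resource; the call below is a minimal sketch under that assumption and is not shown verbatim in this diff.

```python
from vellum import Vellum

client = Vellum(api_key="YOUR_API_KEY")

# Assumed new home for prompt deployment release retrieval, inferred from the
# +121-line change to resources/deployments/client.py; verify against the
# generated reference.md before relying on it.
release = client.deployments.retrieve_prompt_deployment_release(
    id="id",
    release_id_or_release_tag="release_id_or_release_tag",
)
```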
vellum/client/resources/workflow_deployments/client.py
CHANGED
@@ -263,7 +263,8 @@ class WorkflowDeploymentsClient:
         self, history_id_or_release_tag: str, id: str, *, request_options: typing.Optional[RequestOptions] = None
     ) -> WorkflowDeploymentHistoryItem:
         """
-
+        DEPRECATED: This endpoint is deprecated and will be removed in a future release. Please use the
+        `retrieve_workflow_deployment_release` endpoint instead.
 
         Parameters
         ----------
@@ -786,7 +787,8 @@ class AsyncWorkflowDeploymentsClient:
         self, history_id_or_release_tag: str, id: str, *, request_options: typing.Optional[RequestOptions] = None
     ) -> WorkflowDeploymentHistoryItem:
         """
-
+        DEPRECATED: This endpoint is deprecated and will be removed in a future release. Please use the
+        `retrieve_workflow_deployment_release` endpoint instead.
 
         Parameters
        ----------
vellum/client/types/__init__.py
CHANGED
@@ -69,6 +69,7 @@ from .compile_prompt_deployment_expand_meta_request import CompilePromptDeploymentExpandMetaRequest
 from .compile_prompt_meta import CompilePromptMeta
 from .components_schemas_pdf_search_result_meta_source import ComponentsSchemasPdfSearchResultMetaSource
 from .components_schemas_pdf_search_result_meta_source_request import ComponentsSchemasPdfSearchResultMetaSourceRequest
+from .components_schemas_prompt_version_build_config_sandbox import ComponentsSchemasPromptVersionBuildConfigSandbox
 from .condition_combinator import ConditionCombinator
 from .conditional_node_result import ConditionalNodeResult
 from .conditional_node_result_data import ConditionalNodeResultData
@@ -324,6 +325,7 @@ from .prompt_request_input import PromptRequestInput
 from .prompt_request_json_input import PromptRequestJsonInput
 from .prompt_request_string_input import PromptRequestStringInput
 from .prompt_settings import PromptSettings
+from .prompt_version_build_config_sandbox import PromptVersionBuildConfigSandbox
 from .raw_prompt_execution_overrides_request import RawPromptExecutionOverridesRequest
 from .reducto_chunker_config import ReductoChunkerConfig
 from .reducto_chunker_config_request import ReductoChunkerConfigRequest
@@ -677,6 +679,7 @@ __all__ = [
     "CompilePromptMeta",
     "ComponentsSchemasPdfSearchResultMetaSource",
     "ComponentsSchemasPdfSearchResultMetaSourceRequest",
+    "ComponentsSchemasPromptVersionBuildConfigSandbox",
     "ConditionCombinator",
     "ConditionalNodeResult",
     "ConditionalNodeResultData",
@@ -928,6 +931,7 @@ __all__ = [
     "PromptRequestJsonInput",
     "PromptRequestStringInput",
     "PromptSettings",
+    "PromptVersionBuildConfigSandbox",
     "RawPromptExecutionOverridesRequest",
     "ReductoChunkerConfig",
     "ReductoChunkerConfigRequest",
vellum/client/types/prompt_deployment_release_prompt_version.py
CHANGED
@@ -1,13 +1,18 @@
 # This file was auto-generated by Fern from our API Definition.
 
 from ..core.pydantic_utilities import UniversalBaseModel
+from .components_schemas_prompt_version_build_config_sandbox import ComponentsSchemasPromptVersionBuildConfigSandbox
+import pydantic
 from ..core.pydantic_utilities import IS_PYDANTIC_V2
 import typing
-import pydantic
 
 
 class PromptDeploymentReleasePromptVersion(UniversalBaseModel):
     id: str
+    build_config: ComponentsSchemasPromptVersionBuildConfigSandbox = pydantic.Field()
+    """
+    Configuration used to build this prompt version.
+    """
 
     if IS_PYDANTIC_V2:
         model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
vellum/client/types/prompt_version_build_config_sandbox.py
ADDED
@@ -0,0 +1,22 @@
+# This file was auto-generated by Fern from our API Definition.
+
+from ..core.pydantic_utilities import UniversalBaseModel
+import typing
+from ..core.pydantic_utilities import IS_PYDANTIC_V2
+import pydantic
+
+
+class PromptVersionBuildConfigSandbox(UniversalBaseModel):
+    source: typing.Literal["SANDBOX"] = "SANDBOX"
+    sandbox_id: str
+    sandbox_snapshot_id: str
+    prompt_id: str
+
+    if IS_PYDANTIC_V2:
+        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
+    else:
+
+        class Config:
+            frozen = True
+            smart_union = True
+            extra = pydantic.Extra.allow
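A quick illustration of the new build-config type. It is re-exported from `vellum.client.types` per the `__init__.py` change above and surfaces as the `build_config` field on `PromptDeploymentReleasePromptVersion`; the IDs below are placeholders.

```python
from vellum.client.types import PromptVersionBuildConfigSandbox

# Placeholder IDs for illustration only.
build_config = PromptVersionBuildConfigSandbox(
    sandbox_id="sandbox-id",
    sandbox_snapshot_id="sandbox-snapshot-id",
    prompt_id="prompt-id",
)

# The discriminator defaults to "SANDBOX", marking a sandbox-built prompt version.
assert build_config.source == "SANDBOX"
```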
vellum/workflows/nodes/bases/base.py
CHANGED
@@ -2,12 +2,13 @@ from dataclasses import field
 from functools import cached_property, reduce
 import inspect
 from types import MappingProxyType
-from uuid import UUID
+from uuid import UUID, uuid4
 from typing import Any, Dict, Generic, Iterator, Optional, Set, Tuple, Type, TypeVar, Union, cast, get_args
 
 from vellum.workflows.constants import undefined
 from vellum.workflows.descriptors.base import BaseDescriptor
 from vellum.workflows.descriptors.utils import is_unresolved, resolve_value
+from vellum.workflows.edges.edge import Edge
 from vellum.workflows.errors.types import WorkflowErrorCode
 from vellum.workflows.exceptions import NodeException
 from vellum.workflows.graph import Graph
@@ -325,6 +326,36 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
                 code=WorkflowErrorCode.INVALID_INPUTS,
             )
 
+        @classmethod
+        def _queue_node_execution(
+            cls, state: StateType, dependencies: Set["Type[BaseNode]"], invoked_by: Optional[Edge] = None
+        ) -> UUID:
+            """
+            Queues a future execution of a node, returning the span id of the execution.
+
+            We may combine this into the should_initiate method, but we'll keep it separate for now to avoid
+            breaking changes until the 0.15.0 release.
+            """
+
+            execution_id = uuid4()
+            if not invoked_by:
+                return execution_id
+
+            if cls.merge_behavior not in {MergeBehavior.AWAIT_ANY, MergeBehavior.AWAIT_ALL}:
+                return execution_id
+
+            source_node = invoked_by.from_port.node_class
+            for queued_node_execution_id in state.meta.node_execution_cache._node_executions_queued[cls.node_class]:
+                if source_node not in state.meta.node_execution_cache._dependencies_invoked[queued_node_execution_id]:
+                    state.meta.node_execution_cache._invoke_dependency(
+                        queued_node_execution_id, cls.node_class, source_node, dependencies
+                    )
+                    return queued_node_execution_id
+
+            state.meta.node_execution_cache._node_executions_queued[cls.node_class].append(execution_id)
+            state.meta.node_execution_cache._invoke_dependency(execution_id, cls.node_class, source_node, dependencies)
+            return execution_id
+
     class Execution(metaclass=_BaseNodeExecutionMeta):
         node_class: Type["BaseNode"]
         count: int
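The runner calls this hook right before `should_initiate` (see the `runner.py` hunk at the end of this diff): it reserves, or reuses, a span id for the queued execution, and merge nodes (`AWAIT_ANY` / `AWAIT_ALL`) fold the invoking edge into an already-queued execution when possible. A minimal, isolated sketch of the no-edge path follows; `MyNode` is hypothetical, and calling the underscore-prefixed method directly is for illustration only.

```python
from vellum.workflows.nodes.bases.base import BaseNode
from vellum.workflows.state.base import BaseState


class MyNode(BaseNode):
    pass


state = BaseState()

# With no invoking edge, the hook simply returns a fresh uuid4 span id for the
# queued execution; the queueing/merging logic only applies to merge nodes that
# are invoked by an edge.
span_id = MyNode.Trigger._queue_node_execution(state, dependencies=set())
print(span_id)
```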
vellum/workflows/nodes/displayable/code_execution_node/tests/{test_code_execution_node.py → test_node.py}
RENAMED
@@ -882,44 +882,68 @@ def main(arg1: list) -> str:
     assert outputs == {"result": "bar", "log": ""}
 
 
-
-
-
-
+@pytest.mark.parametrize(
+    "code_snippet",
+    [
+        """
 def main(text: str) -> str:
     return text.value
-"""
+""",
+        """
+def main(text: str) -> str:
+    return text["value"]
+""",
+    ],
+)
+def test_run_node__string_value_wrapper_value(code_snippet):
+    """Test string value wrapper value access using different patterns"""
+
+    # GIVEN a node that accesses the 'value' property of a string input
+    class ExampleCodeExecutionNode(CodeExecutionNode[BaseState, str]):
+        code = code_snippet
+        runtime = "PYTHON_3_11_6"
         code_inputs = {
             "text": "hello",
         }
-        runtime = "PYTHON_3_11_6"
 
     # WHEN we run the node
     node = ExampleCodeExecutionNode()
     outputs = node.run()
 
-    # THEN the node should successfully access the string value
+    # THEN the node should successfully access the string value
     assert outputs == {"result": "hello", "log": ""}
 
 
-
-
-
-
+@pytest.mark.parametrize(
+    "code_snippet",
+    [
+        """
 def main(text: str) -> str:
-    return text
-"""
+    return text.type
+""",
+        """
+def main(text: str) -> str:
+    return text["type"]
+""",
+    ],
+)
+def test_run_node__string_value_wrapper_type(code_snippet):
+    """Test string value wrapper type access using different patterns"""
+
+    # GIVEN a node that will return the string type
+    class ExampleCodeExecutionNode(CodeExecutionNode[BaseState, str]):
+        code = code_snippet
+        runtime = "PYTHON_3_11_6"
         code_inputs = {
             "text": "hello",
        }
-        runtime = "PYTHON_3_11_6"
 
     # WHEN we run the node
     node = ExampleCodeExecutionNode()
     outputs = node.run()
 
-    # THEN the node should successfully
-    assert outputs == {"result": "
+    # THEN the node should successfully return the string type
+    assert outputs == {"result": "STRING", "log": ""}
 
 
 def test_run_node__string_value_wrapper__list_of_dicts():
@@ -962,6 +986,87 @@ def main(input: str) -> str:
     assert outputs == {"result": "h", "log": ""}
 
 
+@pytest.mark.parametrize(
+    "code_snippet",
+    [
+        """
+from vellum.client.types.function_call import FunctionCall
+def main(input: FunctionCall) -> FunctionCall:
+    return input.value
+""",
+        """
+from vellum.client.types.function_call import FunctionCall
+def main(input: FunctionCall) -> FunctionCall:
+    return input["value"]
+""",
+    ],
+)
+def test_run_node__function_call_wrapper_value(code_snippet):
+    """Test function call wrapper value access using different patterns"""
+
+    # GIVEN a node that accesses the function call value
+    class ExampleCodeExecutionNode(CodeExecutionNode[BaseState, FunctionCall]):
+        code = code_snippet
+        code_inputs = {
+            "input": FunctionCall(
+                name="test-name",
+                arguments={
+                    "test-key": "test-value",
+                },
+            )
+        }
+
+    # WHEN we run the node
+    node = ExampleCodeExecutionNode()
+    outputs = node.run()
+
+    # THEN the node should successfully return the function call value
+    assert isinstance(outputs.result, FunctionCall)
+    assert outputs.result.name == "test-name"
+    assert outputs.result.arguments == {"test-key": "test-value"}
+    assert outputs.result.id is None
+    assert outputs.log == ""
+
+
+@pytest.mark.parametrize(
+    "code_snippet",
+    [
+        """
+from vellum.client.types.function_call import FunctionCall
+def main(input: FunctionCall) -> str:
+    return input.type
+""",
+        """
+from vellum.client.types.function_call import FunctionCall
+def main(input: FunctionCall) -> str:
+    return input["type"]
+""",
+    ],
+)
+def test_run_node__function_call_wrapper_type(code_snippet):
+    """Test function call wrapper type access using different patterns"""
+
+    # GIVEN a node that accesses the function call type
+    class ExampleCodeExecutionNode(CodeExecutionNode[BaseState, str]):
+        code = code_snippet
+        runtime = "PYTHON_3_11_6"
+        code_inputs = {
+            "input": FunctionCall(
+                name="test-name",
+                arguments={
+                    "test-key": "test-value",
+                },
+            )
+        }
+
+    # WHEN we run the node
+    node = ExampleCodeExecutionNode()
+    outputs = node.run()
+
+    # THEN the node should successfully return the function call type
+    assert outputs == {"result": "FUNCTION_CALL", "log": ""}
+
+
 def test_run_node__iter_list():
     # GIVEN a node that will return the first string in a list
     class ExampleCodeExecutionNode(CodeExecutionNode[BaseState, str]):
@@ -1044,3 +1149,21 @@ Traceback (most recent call last):
 IndexError: list index out of range
 """
     )
+
+
+def test_run_node__default_function_call_type():
+    # GIVEN a node that will return a FunctionCall
+    class ExampleCodeExecutionNode(CodeExecutionNode[BaseState, FunctionCall]):
+        code = """\
+def main(input: str) -> str:
+    return None
+"""
+        runtime = "PYTHON_3_11_6"
+        code_inputs = {"input": "foo"}
+
+    # WHEN we run the node
+    node = ExampleCodeExecutionNode()
+    outputs = node.run()
+
+    # THEN the node should return default function call
+    assert outputs == {"result": FunctionCall(name="", arguments={}), "log": ""}
vellum/workflows/nodes/experimental/tool_calling_node/node.py
CHANGED
@@ -43,8 +43,8 @@ class ToolCallingNode(BaseNode):
         chat_history: The complete chat history including tool calls
         """
 
-        text: str
-        chat_history: List[ChatMessage]
+        text: str
+        chat_history: List[ChatMessage]
 
     def run(self) -> Outputs:
         """
@@ -81,10 +81,10 @@ class ToolCallingNode(BaseNode):
                 message="Subworkflow unexpectedly paused",
             )
         elif terminal_event.name == "workflow.execution.fulfilled":
-            node_outputs = self.Outputs(
-
-
-
+            node_outputs = self.Outputs(
+                text=terminal_event.outputs.text,
+                chat_history=terminal_event.outputs.chat_history,
+            )
 
             return node_outputs
         elif terminal_event.name == "workflow.execution.rejected":
vellum/workflows/nodes/experimental/tool_calling_node/utils.py
CHANGED
@@ -11,7 +11,7 @@ from vellum.workflows.nodes.displayable.inline_prompt_node.node import InlinePromptNode
 from vellum.workflows.outputs.base import BaseOutput
 from vellum.workflows.ports.port import Port
 from vellum.workflows.references.lazy import LazyReference
-from vellum.workflows.types.core import EntityInputsInterface
+from vellum.workflows.types.core import EntityInputsInterface, MergeBehavior
 
 
 class FunctionNode(BaseNode):
@@ -21,6 +21,9 @@ class FunctionNode(BaseNode):
 
 
 class ToolRouterNode(InlinePromptNode):
+    class Trigger(InlinePromptNode.Trigger):
+        merge_behavior = MergeBehavior.AWAIT_ATTRIBUTES
+
     def run(self) -> Iterator[BaseOutput]:
         self.prompt_inputs = {**self.prompt_inputs, "chat_history": self.state.chat_history}  # type: ignore
         generator = super().run()
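`AWAIT_ATTRIBUTES` on the router's trigger presumably makes the prompt node wait until its attribute dependencies are resolved before initiating. The same override is available to user-defined nodes; a minimal sketch, where `MyMergeNode` is hypothetical and not part of the SDK:

```python
from vellum.workflows.nodes.bases.base import BaseNode
from vellum.workflows.types.core import MergeBehavior


class MyMergeNode(BaseNode):
    # Mirror the ToolRouterNode change above: wait on attribute dependencies
    # before this node's trigger initiates.
    class Trigger(BaseNode.Trigger):
        merge_behavior = MergeBehavior.AWAIT_ATTRIBUTES
```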
vellum/workflows/nodes/tests/test_utils.py
CHANGED
@@ -4,7 +4,7 @@ from typing import Any, List, Union
 from pydantic import BaseModel
 
 from vellum.client.types.chat_message import ChatMessage
-from vellum.client.types.function_call import FunctionCall
+from vellum.client.types.function_call import FunctionCall
 from vellum.client.types.vellum_value import VellumValue
 from vellum.workflows.errors.types import WorkflowErrorCode
 from vellum.workflows.exceptions import NodeException
@@ -15,11 +15,7 @@ from vellum.workflows.types.core import Json
 class Person(BaseModel):
     name: str
     age: int
-
-
-class FunctionCall(BaseModel):
-    name: str
-    args: List[int]
+    colors: List[str]
 
 
 @pytest.mark.parametrize(
@@ -61,34 +57,25 @@ def test_parse_type_from_str_basic_cases(input_str, output_type, expected_result):
 
 
 def test_parse_type_from_str_pydantic_models():
-    person_json = '{"name": "Alice", "age": 30}'
+    person_json = '{"name": "Alice", "age": 30, "colors": ["red", "blue"]}'
     person = parse_type_from_str(person_json, Person)
     assert isinstance(person, Person)
     assert person.name == "Alice"
     assert person.age == 30
-
-    function_json = '{"name": "test", "args": [1, 2]}'
-    function = parse_type_from_str(function_json, FunctionCall)
-    assert isinstance(function, FunctionCall)
-    assert function.name == "test"
-    assert function.args == [1, 2]
-
-    function_call_json = '{"value": {"name": "test", "args": [1, 2]}}'
-    function = parse_type_from_str(function_call_json, FunctionCall)
-    assert isinstance(function, FunctionCall)
-    assert function.name == "test"
-    assert function.args == [1, 2]
+    assert person.colors == ["red", "blue"]
 
 
 def test_parse_type_from_str_list_of_models():
-    person_list_json = '[{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]'
+    person_list_json = '[{"name": "Alice", "age": 30, "colors": ["red", "blue"]}, {"name": "Bob", "age": 25, "colors": ["green", "yellow"]}]'  # noqa: E501
     persons = parse_type_from_str(person_list_json, List[Person])
     assert len(persons) == 2
     assert all(isinstance(p, Person) for p in persons)
     assert persons[0].name == "Alice"
     assert persons[0].age == 30
+    assert persons[0].colors == ["red", "blue"]
     assert persons[1].name == "Bob"
     assert persons[1].age == 25
+    assert persons[1].colors == ["green", "yellow"]
 
 
 @pytest.mark.parametrize(
@@ -143,7 +130,7 @@ def test_parse_type_from_str_error_cases(input_str, output_type, expected_exception):
         (int, 0),  # Number
         (float, 0.0),  # Number
         (Any, None),  # Json
-        (
+        (FunctionCall, FunctionCall(name="", arguments={})),  # FunctionCall
         (List[ChatMessage], []),  # Chat History
         (List[VellumValue], []),  # Array
         (Union[float, int], 0.0),  # Union
vellum/workflows/nodes/utils.py
CHANGED
@@ -230,7 +230,10 @@ def _get_default_value(output_type: Any) -> Any:
     elif origin is list:
         return []
     elif output_type is FunctionCall:
-        return
+        return FunctionCall(
+            name="",
+            arguments={},
+        )
     elif origin is Union:
         # Always use the first argument type's default value
         if args:
vellum/workflows/runner/runner.py
CHANGED
@@ -431,7 +431,7 @@ class WorkflowRunner(Generic[StateType]):
             return
 
         all_deps = self._dependencies[node_class]
-        node_span_id =
+        node_span_id = node_class.Trigger._queue_node_execution(state, all_deps, invoked_by)
         if not node_class.Trigger.should_initiate(state, all_deps, node_span_id):
             return
 