vellum-ai 0.12.3__py3-none-any.whl → 0.12.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/resources/workflows/client.py +32 -0
- vellum/client/types/chat_message_prompt_block.py +1 -1
- vellum/client/types/function_definition.py +26 -7
- vellum/client/types/jinja_prompt_block.py +1 -1
- vellum/client/types/plain_text_prompt_block.py +1 -1
- vellum/client/types/rich_text_prompt_block.py +1 -1
- vellum/client/types/variable_prompt_block.py +1 -1
- vellum/plugins/vellum_mypy.py +20 -23
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +21 -8
- vellum/workflows/nodes/displayable/inline_prompt_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +64 -0
- vellum/workflows/sandbox.py +51 -0
- vellum/workflows/tests/__init__.py +0 -0
- vellum/workflows/tests/test_sandbox.py +62 -0
- vellum/workflows/types/core.py +2 -52
- vellum/workflows/utils/functions.py +41 -4
- vellum/workflows/utils/tests/test_functions.py +93 -0
- {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/METADATA +1 -1
- {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/RECORD +33 -28
- vellum_cli/__init__.py +14 -0
- vellum_cli/pull.py +16 -2
- vellum_cli/tests/test_pull.py +45 -0
- vellum_ee/workflows/display/nodes/base_node_vellum_display.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +1 -3
- vellum_ee/workflows/display/nodes/vellum/conditional_node.py +18 -18
- vellum_ee/workflows/display/nodes/vellum/search_node.py +19 -15
- vellum_ee/workflows/display/nodes/vellum/templating_node.py +1 -1
- vellum_ee/workflows/display/nodes/vellum/utils.py +4 -4
- vellum_ee/workflows/display/utils/vellum.py +1 -2
- {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/LICENSE +0 -0
- {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/WHEEL +0 -0
- {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
|
|
18
18
|
headers: typing.Dict[str, str] = {
|
19
19
|
"X-Fern-Language": "Python",
|
20
20
|
"X-Fern-SDK-Name": "vellum-ai",
|
21
|
-
"X-Fern-SDK-Version": "0.12.
|
21
|
+
"X-Fern-SDK-Version": "0.12.5",
|
22
22
|
}
|
23
23
|
headers["X_API_KEY"] = self.api_key
|
24
24
|
return headers
|
@@ -27,7 +27,11 @@ class WorkflowsClient:
|
|
27
27
|
self,
|
28
28
|
id: str,
|
29
29
|
*,
|
30
|
+
exclude_code: typing.Optional[bool] = None,
|
30
31
|
format: typing.Optional[WorkflowsPullRequestFormat] = None,
|
32
|
+
include_json: typing.Optional[bool] = None,
|
33
|
+
include_sandbox: typing.Optional[bool] = None,
|
34
|
+
strict: typing.Optional[bool] = None,
|
31
35
|
request_options: typing.Optional[RequestOptions] = None,
|
32
36
|
) -> typing.Iterator[bytes]:
|
33
37
|
"""
|
@@ -38,8 +42,16 @@ class WorkflowsClient:
|
|
38
42
|
id : str
|
39
43
|
The ID of the Workflow to pull from
|
40
44
|
|
45
|
+
exclude_code : typing.Optional[bool]
|
46
|
+
|
41
47
|
format : typing.Optional[WorkflowsPullRequestFormat]
|
42
48
|
|
49
|
+
include_json : typing.Optional[bool]
|
50
|
+
|
51
|
+
include_sandbox : typing.Optional[bool]
|
52
|
+
|
53
|
+
strict : typing.Optional[bool]
|
54
|
+
|
43
55
|
request_options : typing.Optional[RequestOptions]
|
44
56
|
Request-specific configuration.
|
45
57
|
|
@@ -53,7 +65,11 @@ class WorkflowsClient:
|
|
53
65
|
base_url=self._client_wrapper.get_environment().default,
|
54
66
|
method="GET",
|
55
67
|
params={
|
68
|
+
"exclude_code": exclude_code,
|
56
69
|
"format": format,
|
70
|
+
"include_json": include_json,
|
71
|
+
"include_sandbox": include_sandbox,
|
72
|
+
"strict": strict,
|
57
73
|
},
|
58
74
|
request_options=request_options,
|
59
75
|
) as _response:
|
@@ -164,7 +180,11 @@ class AsyncWorkflowsClient:
|
|
164
180
|
self,
|
165
181
|
id: str,
|
166
182
|
*,
|
183
|
+
exclude_code: typing.Optional[bool] = None,
|
167
184
|
format: typing.Optional[WorkflowsPullRequestFormat] = None,
|
185
|
+
include_json: typing.Optional[bool] = None,
|
186
|
+
include_sandbox: typing.Optional[bool] = None,
|
187
|
+
strict: typing.Optional[bool] = None,
|
168
188
|
request_options: typing.Optional[RequestOptions] = None,
|
169
189
|
) -> typing.AsyncIterator[bytes]:
|
170
190
|
"""
|
@@ -175,8 +195,16 @@ class AsyncWorkflowsClient:
|
|
175
195
|
id : str
|
176
196
|
The ID of the Workflow to pull from
|
177
197
|
|
198
|
+
exclude_code : typing.Optional[bool]
|
199
|
+
|
178
200
|
format : typing.Optional[WorkflowsPullRequestFormat]
|
179
201
|
|
202
|
+
include_json : typing.Optional[bool]
|
203
|
+
|
204
|
+
include_sandbox : typing.Optional[bool]
|
205
|
+
|
206
|
+
strict : typing.Optional[bool]
|
207
|
+
|
180
208
|
request_options : typing.Optional[RequestOptions]
|
181
209
|
Request-specific configuration.
|
182
210
|
|
@@ -190,7 +218,11 @@ class AsyncWorkflowsClient:
|
|
190
218
|
base_url=self._client_wrapper.get_environment().default,
|
191
219
|
method="GET",
|
192
220
|
params={
|
221
|
+
"exclude_code": exclude_code,
|
193
222
|
"format": format,
|
223
|
+
"include_json": include_json,
|
224
|
+
"include_sandbox": include_sandbox,
|
225
|
+
"strict": strict,
|
194
226
|
},
|
195
227
|
request_options=request_options,
|
196
228
|
) as _response:
|
@@ -16,9 +16,9 @@ class ChatMessagePromptBlock(UniversalBaseModel):
|
|
16
16
|
A block that represents a chat message in a prompt template.
|
17
17
|
"""
|
18
18
|
|
19
|
+
block_type: typing.Literal["CHAT_MESSAGE"] = "CHAT_MESSAGE"
|
19
20
|
state: typing.Optional[PromptBlockState] = None
|
20
21
|
cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
|
21
|
-
block_type: typing.Literal["CHAT_MESSAGE"] = "CHAT_MESSAGE"
|
22
22
|
chat_role: ChatMessageRole
|
23
23
|
chat_source: typing.Optional[str] = None
|
24
24
|
chat_message_unterminated: typing.Optional[bool] = None
|
@@ -4,22 +4,41 @@ from ..core.pydantic_utilities import UniversalBaseModel
|
|
4
4
|
import typing
|
5
5
|
from .prompt_block_state import PromptBlockState
|
6
6
|
from .ephemeral_prompt_cache_config import EphemeralPromptCacheConfig
|
7
|
-
from ..core.pydantic_utilities import IS_PYDANTIC_V2
|
8
7
|
import pydantic
|
8
|
+
from ..core.pydantic_utilities import IS_PYDANTIC_V2
|
9
9
|
|
10
10
|
|
11
11
|
class FunctionDefinition(UniversalBaseModel):
|
12
12
|
"""
|
13
|
-
|
13
|
+
The definition of a Function (aka "Tool Call") that a Prompt/Model has access to.
|
14
14
|
"""
|
15
15
|
|
16
16
|
state: typing.Optional[PromptBlockState] = None
|
17
17
|
cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
|
18
|
-
name: typing.Optional[str] = None
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
18
|
+
name: typing.Optional[str] = pydantic.Field(default=None)
|
19
|
+
"""
|
20
|
+
The name identifying the function.
|
21
|
+
"""
|
22
|
+
|
23
|
+
description: typing.Optional[str] = pydantic.Field(default=None)
|
24
|
+
"""
|
25
|
+
A description to help guide the model when to invoke this function.
|
26
|
+
"""
|
27
|
+
|
28
|
+
parameters: typing.Optional[typing.Dict[str, typing.Optional[typing.Any]]] = pydantic.Field(default=None)
|
29
|
+
"""
|
30
|
+
An OpenAPI specification of parameters that are supported by this function.
|
31
|
+
"""
|
32
|
+
|
33
|
+
forced: typing.Optional[bool] = pydantic.Field(default=None)
|
34
|
+
"""
|
35
|
+
Set this option to true to force the model to return a function call of this function.
|
36
|
+
"""
|
37
|
+
|
38
|
+
strict: typing.Optional[bool] = pydantic.Field(default=None)
|
39
|
+
"""
|
40
|
+
Set this option to use strict schema decoding when available.
|
41
|
+
"""
|
23
42
|
|
24
43
|
if IS_PYDANTIC_V2:
|
25
44
|
model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
|
@@ -13,9 +13,9 @@ class JinjaPromptBlock(UniversalBaseModel):
|
|
13
13
|
A block of Jinja template code that is used to generate a prompt
|
14
14
|
"""
|
15
15
|
|
16
|
+
block_type: typing.Literal["JINJA"] = "JINJA"
|
16
17
|
state: typing.Optional[PromptBlockState] = None
|
17
18
|
cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
|
18
|
-
block_type: typing.Literal["JINJA"] = "JINJA"
|
19
19
|
template: str
|
20
20
|
|
21
21
|
if IS_PYDANTIC_V2:
|
@@ -13,9 +13,9 @@ class PlainTextPromptBlock(UniversalBaseModel):
|
|
13
13
|
A block that holds a plain text string value.
|
14
14
|
"""
|
15
15
|
|
16
|
+
block_type: typing.Literal["PLAIN_TEXT"] = "PLAIN_TEXT"
|
16
17
|
state: typing.Optional[PromptBlockState] = None
|
17
18
|
cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
|
18
|
-
block_type: typing.Literal["PLAIN_TEXT"] = "PLAIN_TEXT"
|
19
19
|
text: str
|
20
20
|
|
21
21
|
if IS_PYDANTIC_V2:
|
@@ -14,9 +14,9 @@ class RichTextPromptBlock(UniversalBaseModel):
|
|
14
14
|
A block that includes a combination of plain text and variable blocks.
|
15
15
|
"""
|
16
16
|
|
17
|
+
block_type: typing.Literal["RICH_TEXT"] = "RICH_TEXT"
|
17
18
|
state: typing.Optional[PromptBlockState] = None
|
18
19
|
cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
|
19
|
-
block_type: typing.Literal["RICH_TEXT"] = "RICH_TEXT"
|
20
20
|
blocks: typing.List[RichTextChildBlock]
|
21
21
|
|
22
22
|
if IS_PYDANTIC_V2:
|
@@ -13,9 +13,9 @@ class VariablePromptBlock(UniversalBaseModel):
|
|
13
13
|
A block that represents a variable in a prompt template.
|
14
14
|
"""
|
15
15
|
|
16
|
+
block_type: typing.Literal["VARIABLE"] = "VARIABLE"
|
16
17
|
state: typing.Optional[PromptBlockState] = None
|
17
18
|
cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
|
18
|
-
block_type: typing.Literal["VARIABLE"] = "VARIABLE"
|
19
19
|
input_variable: str
|
20
20
|
|
21
21
|
if IS_PYDANTIC_V2:
|
vellum/plugins/vellum_mypy.py
CHANGED
@@ -172,41 +172,39 @@ class VellumMypyPlugin(Plugin):
|
|
172
172
|
|
173
173
|
def _dynamic_output_node_class_hook(self, ctx: ClassDefContext, attribute_name: str) -> None:
|
174
174
|
"""
|
175
|
-
We use this hook to properly annotate the Outputs class for Templating Node
|
176
|
-
|
175
|
+
We use this hook to properly annotate the Outputs class for Templating Node and
|
176
|
+
Code Execution Node using the resolved type
|
177
|
+
of the TemplatingNode's and CodeExecutionNode's class _OutputType generic.
|
177
178
|
"""
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
if not templating_node_bases:
|
179
|
+
node_info = ctx.cls.info
|
180
|
+
node_bases = ctx.cls.info.bases
|
181
|
+
if not node_bases:
|
182
182
|
return
|
183
|
-
if not isinstance(
|
183
|
+
if not isinstance(node_bases[0], Instance):
|
184
184
|
return
|
185
185
|
|
186
|
-
|
187
|
-
|
188
|
-
if not _is_subclass(base_templating_node, "vellum.workflows.nodes.core.templating_node.node.TemplatingNode"):
|
189
|
-
return
|
186
|
+
base_args = node_bases[0].args
|
187
|
+
base_node = node_bases[0].type
|
190
188
|
|
191
|
-
if len(
|
189
|
+
if len(base_args) != 2:
|
192
190
|
return
|
193
191
|
|
194
|
-
|
195
|
-
if isinstance(
|
196
|
-
|
192
|
+
base_node_resolved_type = base_args[1]
|
193
|
+
if isinstance(base_node_resolved_type, AnyType):
|
194
|
+
base_node_resolved_type = ctx.api.named_type("builtins.str")
|
197
195
|
|
198
|
-
|
199
|
-
if not
|
196
|
+
base_node_outputs = base_node.names.get("Outputs")
|
197
|
+
if not base_node_outputs:
|
200
198
|
return
|
201
199
|
|
202
|
-
|
203
|
-
if not
|
204
|
-
|
205
|
-
new_outputs_sym =
|
200
|
+
current_node_outputs = node_info.names.get("Outputs")
|
201
|
+
if not current_node_outputs:
|
202
|
+
node_info.names["Outputs"] = base_node_outputs.copy()
|
203
|
+
new_outputs_sym = node_info.names["Outputs"].node
|
206
204
|
if isinstance(new_outputs_sym, TypeInfo):
|
207
205
|
result_sym = new_outputs_sym.names[attribute_name].node
|
208
206
|
if isinstance(result_sym, Var):
|
209
|
-
result_sym.type =
|
207
|
+
result_sym.type = base_node_resolved_type
|
210
208
|
|
211
209
|
def _base_node_class_hook(self, ctx: ClassDefContext) -> None:
|
212
210
|
"""
|
@@ -233,7 +231,6 @@ class VellumMypyPlugin(Plugin):
|
|
233
231
|
type_ = sym.node.type
|
234
232
|
if not type_:
|
235
233
|
continue
|
236
|
-
|
237
234
|
sym.node.type = self._get_resolvable_type(
|
238
235
|
lambda fullname, types: ctx.api.named_type(fullname, types), type_
|
239
236
|
)
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import json
|
1
2
|
from uuid import uuid4
|
2
3
|
from typing import ClassVar, Generic, Iterator, List, Optional, Tuple, cast
|
3
4
|
|
@@ -18,6 +19,7 @@ from vellum.client import RequestOptions
|
|
18
19
|
from vellum.workflows.constants import OMIT
|
19
20
|
from vellum.workflows.context import get_parent_context
|
20
21
|
from vellum.workflows.errors import WorkflowErrorCode
|
22
|
+
from vellum.workflows.events.types import default_serializer
|
21
23
|
from vellum.workflows.exceptions import NodeException
|
22
24
|
from vellum.workflows.nodes.displayable.bases.base_prompt_node import BasePromptNode
|
23
25
|
from vellum.workflows.nodes.displayable.bases.inline_prompt_node.constants import DEFAULT_PROMPT_PARAMETERS
|
@@ -108,19 +110,30 @@ class BaseInlinePromptNode(BasePromptNode, Generic[StateType]):
|
|
108
110
|
value=cast(List[ChatMessage], input_value),
|
109
111
|
)
|
110
112
|
)
|
111
|
-
|
112
|
-
|
113
|
-
|
113
|
+
else:
|
114
|
+
try:
|
115
|
+
input_value = default_serializer(input_value)
|
116
|
+
except json.JSONDecodeError as e:
|
117
|
+
raise NodeException(
|
118
|
+
message=f"Failed to serialize input '{input_name}' of type '{input_value.__class__}': {e}",
|
119
|
+
code=WorkflowErrorCode.INVALID_INPUTS,
|
120
|
+
)
|
121
|
+
|
122
|
+
input_variables.append(
|
123
|
+
VellumVariable(
|
124
|
+
# TODO: Determine whether or not we actually need an id here and if we do,
|
125
|
+
# figure out how to maintain stable id references.
|
126
|
+
# https://app.shortcut.com/vellum/story/4080
|
127
|
+
id=str(uuid4()),
|
128
|
+
key=input_name,
|
129
|
+
type="JSON",
|
130
|
+
)
|
131
|
+
)
|
114
132
|
input_values.append(
|
115
133
|
PromptRequestJsonInput(
|
116
134
|
key=input_name,
|
117
135
|
value=input_value,
|
118
136
|
)
|
119
137
|
)
|
120
|
-
else:
|
121
|
-
raise NodeException(
|
122
|
-
message=f"Unrecognized input type for input '{input_name}': {input_value.__class__}",
|
123
|
-
code=WorkflowErrorCode.INVALID_INPUTS,
|
124
|
-
)
|
125
138
|
|
126
139
|
return input_variables, input_values
|
File without changes
|
@@ -0,0 +1,64 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from uuid import uuid4
|
3
|
+
from typing import Any, Iterator, List
|
4
|
+
|
5
|
+
from vellum.client.core.pydantic_utilities import UniversalBaseModel
|
6
|
+
from vellum.client.types.execute_prompt_event import ExecutePromptEvent
|
7
|
+
from vellum.client.types.fulfilled_execute_prompt_event import FulfilledExecutePromptEvent
|
8
|
+
from vellum.client.types.initiated_execute_prompt_event import InitiatedExecutePromptEvent
|
9
|
+
from vellum.client.types.prompt_output import PromptOutput
|
10
|
+
from vellum.client.types.prompt_request_json_input import PromptRequestJsonInput
|
11
|
+
from vellum.client.types.string_vellum_value import StringVellumValue
|
12
|
+
from vellum.workflows.nodes.displayable.inline_prompt_node.node import InlinePromptNode
|
13
|
+
|
14
|
+
|
15
|
+
def test_inline_prompt_node__json_inputs(vellum_adhoc_prompt_client):
|
16
|
+
# GIVEN a prompt node with various inputs
|
17
|
+
@dataclass
|
18
|
+
class MyDataClass:
|
19
|
+
hello: str
|
20
|
+
|
21
|
+
class MyPydantic(UniversalBaseModel):
|
22
|
+
example: str
|
23
|
+
|
24
|
+
class MyNode(InlinePromptNode):
|
25
|
+
ml_model = "gpt-4o"
|
26
|
+
blocks = []
|
27
|
+
prompt_inputs = {
|
28
|
+
"a_dict": {"foo": "bar"},
|
29
|
+
"a_list": [1, 2, 3],
|
30
|
+
"a_dataclass": MyDataClass(hello="world"),
|
31
|
+
"a_pydantic": MyPydantic(example="example"),
|
32
|
+
}
|
33
|
+
|
34
|
+
# AND a known response from invoking an inline prompt
|
35
|
+
expected_outputs: List[PromptOutput] = [
|
36
|
+
StringVellumValue(value="Test"),
|
37
|
+
]
|
38
|
+
|
39
|
+
def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
|
40
|
+
execution_id = str(uuid4())
|
41
|
+
events: List[ExecutePromptEvent] = [
|
42
|
+
InitiatedExecutePromptEvent(execution_id=execution_id),
|
43
|
+
FulfilledExecutePromptEvent(
|
44
|
+
execution_id=execution_id,
|
45
|
+
outputs=expected_outputs,
|
46
|
+
),
|
47
|
+
]
|
48
|
+
yield from events
|
49
|
+
|
50
|
+
vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.side_effect = generate_prompt_events
|
51
|
+
|
52
|
+
# WHEN the node is run
|
53
|
+
list(MyNode().run())
|
54
|
+
|
55
|
+
# THEN the prompt is executed with the correct inputs
|
56
|
+
mock_api = vellum_adhoc_prompt_client.adhoc_execute_prompt_stream
|
57
|
+
assert mock_api.call_count == 1
|
58
|
+
assert mock_api.call_args.kwargs["input_values"] == [
|
59
|
+
PromptRequestJsonInput(key="a_dict", type="JSON", value={"foo": "bar"}),
|
60
|
+
PromptRequestJsonInput(key="a_list", type="JSON", value=[1, 2, 3]),
|
61
|
+
PromptRequestJsonInput(key="a_dataclass", type="JSON", value={"hello": "world"}),
|
62
|
+
PromptRequestJsonInput(key="a_pydantic", type="JSON", value={"example": "example"}),
|
63
|
+
]
|
64
|
+
assert len(mock_api.call_args.kwargs["input_variables"]) == 4
|
@@ -0,0 +1,51 @@
|
|
1
|
+
from typing import Generic, Sequence, Type
|
2
|
+
|
3
|
+
import dotenv
|
4
|
+
|
5
|
+
from vellum.workflows.events.workflow import WorkflowEventStream
|
6
|
+
from vellum.workflows.inputs.base import BaseInputs
|
7
|
+
from vellum.workflows.logging import load_logger
|
8
|
+
from vellum.workflows.types.generics import WorkflowType
|
9
|
+
from vellum.workflows.workflows.event_filters import root_workflow_event_filter
|
10
|
+
|
11
|
+
|
12
|
+
class SandboxRunner(Generic[WorkflowType]):
|
13
|
+
def __init__(self, workflow: Type[WorkflowType], inputs: Sequence[BaseInputs]):
|
14
|
+
if not inputs:
|
15
|
+
raise ValueError("Inputs are required to have at least one defined inputs")
|
16
|
+
|
17
|
+
self._workflow = workflow
|
18
|
+
self._inputs = inputs
|
19
|
+
|
20
|
+
dotenv.load_dotenv()
|
21
|
+
self._logger = load_logger()
|
22
|
+
|
23
|
+
def run(self, index: int = 0):
|
24
|
+
if index < 0:
|
25
|
+
self._logger.warning("Index is less than 0, running first input")
|
26
|
+
index = 0
|
27
|
+
elif index >= len(self._inputs):
|
28
|
+
self._logger.warning("Index is greater than the number of provided inputs, running last input")
|
29
|
+
index = len(self._inputs) - 1
|
30
|
+
|
31
|
+
selected_inputs = self._inputs[index]
|
32
|
+
|
33
|
+
workflow = self._workflow()
|
34
|
+
events = workflow.stream(
|
35
|
+
inputs=selected_inputs,
|
36
|
+
event_filter=root_workflow_event_filter,
|
37
|
+
)
|
38
|
+
|
39
|
+
self._process_events(events)
|
40
|
+
|
41
|
+
def _process_events(self, events: WorkflowEventStream):
|
42
|
+
for event in events:
|
43
|
+
if event.name == "workflow.execution.fulfilled":
|
44
|
+
self._logger.info("Workflow fulfilled!")
|
45
|
+
for output_reference, value in event.outputs:
|
46
|
+
self._logger.info("----------------------------------")
|
47
|
+
self._logger.info(f"{output_reference.name}: {value}")
|
48
|
+
elif event.name == "node.execution.initiated":
|
49
|
+
self._logger.info(f"Just started Node: {event.node_definition.__name__}")
|
50
|
+
elif event.name == "node.execution.fulfilled":
|
51
|
+
self._logger.info(f"Just finished Node: {event.node_definition.__name__}")
|
File without changes
|
@@ -0,0 +1,62 @@
|
|
1
|
+
import pytest
|
2
|
+
from typing import List
|
3
|
+
|
4
|
+
from vellum.workflows.inputs.base import BaseInputs
|
5
|
+
from vellum.workflows.nodes.bases.base import BaseNode
|
6
|
+
from vellum.workflows.sandbox import SandboxRunner
|
7
|
+
from vellum.workflows.state.base import BaseState
|
8
|
+
from vellum.workflows.workflows.base import BaseWorkflow
|
9
|
+
|
10
|
+
|
11
|
+
@pytest.fixture
|
12
|
+
def mock_logger(mocker):
|
13
|
+
return mocker.patch("vellum.workflows.sandbox.load_logger")
|
14
|
+
|
15
|
+
|
16
|
+
@pytest.mark.parametrize(
|
17
|
+
["run_kwargs", "expected_last_log"],
|
18
|
+
[
|
19
|
+
({}, "final_results: first"),
|
20
|
+
({"index": 1}, "final_results: second"),
|
21
|
+
({"index": -4}, "final_results: first"),
|
22
|
+
({"index": 100}, "final_results: second"),
|
23
|
+
],
|
24
|
+
ids=["default", "specific", "negative", "out_of_bounds"],
|
25
|
+
)
|
26
|
+
def test_sandbox_runner__happy_path(mock_logger, run_kwargs, expected_last_log):
|
27
|
+
# GIVEN we capture the logs to stdout
|
28
|
+
logs = []
|
29
|
+
mock_logger.return_value.info.side_effect = lambda msg: logs.append(msg)
|
30
|
+
|
31
|
+
# AND an example workflow
|
32
|
+
class Inputs(BaseInputs):
|
33
|
+
foo: str
|
34
|
+
|
35
|
+
class StartNode(BaseNode):
|
36
|
+
class Outputs(BaseNode.Outputs):
|
37
|
+
bar = Inputs.foo
|
38
|
+
|
39
|
+
class Workflow(BaseWorkflow[Inputs, BaseState]):
|
40
|
+
graph = StartNode
|
41
|
+
|
42
|
+
class Outputs(BaseWorkflow.Outputs):
|
43
|
+
final_results = StartNode.Outputs.bar
|
44
|
+
|
45
|
+
# AND a dataset for this workflow
|
46
|
+
inputs: List[Inputs] = [
|
47
|
+
Inputs(foo="first"),
|
48
|
+
Inputs(foo="second"),
|
49
|
+
]
|
50
|
+
|
51
|
+
# WHEN we run the sandbox
|
52
|
+
runner = SandboxRunner(workflow=Workflow, inputs=inputs)
|
53
|
+
runner.run(**run_kwargs)
|
54
|
+
|
55
|
+
# THEN we see the logs
|
56
|
+
assert logs == [
|
57
|
+
"Just started Node: StartNode",
|
58
|
+
"Just finished Node: StartNode",
|
59
|
+
"Workflow fulfilled!",
|
60
|
+
"----------------------------------",
|
61
|
+
expected_last_log,
|
62
|
+
]
|
vellum/workflows/types/core.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
from enum import Enum
|
2
2
|
from typing import ( # type: ignore[attr-defined]
|
3
|
+
Any,
|
3
4
|
Dict,
|
4
5
|
List,
|
5
6
|
Union,
|
@@ -8,22 +9,6 @@ from typing import ( # type: ignore[attr-defined]
|
|
8
9
|
_UnionGenericAlias,
|
9
10
|
)
|
10
11
|
|
11
|
-
from vellum import (
|
12
|
-
ChatMessage,
|
13
|
-
FunctionCall,
|
14
|
-
FunctionCallRequest,
|
15
|
-
SearchResult,
|
16
|
-
SearchResultRequest,
|
17
|
-
VellumAudio,
|
18
|
-
VellumAudioRequest,
|
19
|
-
VellumError,
|
20
|
-
VellumErrorRequest,
|
21
|
-
VellumImage,
|
22
|
-
VellumImageRequest,
|
23
|
-
VellumValue,
|
24
|
-
VellumValueRequest,
|
25
|
-
)
|
26
|
-
|
27
12
|
JsonArray = List["Json"]
|
28
13
|
JsonObject = Dict[str, "Json"]
|
29
14
|
Json = Union[None, bool, int, float, str, JsonArray, JsonObject]
|
@@ -42,42 +27,7 @@ class VellumSecret:
|
|
42
27
|
self.name = name
|
43
28
|
|
44
29
|
|
45
|
-
|
46
|
-
# String inputs
|
47
|
-
str,
|
48
|
-
# Chat history inputs
|
49
|
-
List[ChatMessage],
|
50
|
-
List[ChatMessage],
|
51
|
-
# Search results inputs
|
52
|
-
List[SearchResultRequest],
|
53
|
-
List[SearchResult],
|
54
|
-
# JSON inputs
|
55
|
-
Json,
|
56
|
-
# Number inputs
|
57
|
-
float,
|
58
|
-
# Function Call Inputs
|
59
|
-
FunctionCall,
|
60
|
-
FunctionCallRequest,
|
61
|
-
# Error Inputs
|
62
|
-
VellumError,
|
63
|
-
VellumErrorRequest,
|
64
|
-
# Array Inputs
|
65
|
-
List[VellumValueRequest],
|
66
|
-
List[VellumValue],
|
67
|
-
# Image Inputs
|
68
|
-
VellumImage,
|
69
|
-
VellumImageRequest,
|
70
|
-
# Audio Inputs
|
71
|
-
VellumAudio,
|
72
|
-
VellumAudioRequest,
|
73
|
-
# Vellum Secrets
|
74
|
-
VellumSecret,
|
75
|
-
]
|
76
|
-
|
77
|
-
EntityInputsInterface = Dict[
|
78
|
-
str,
|
79
|
-
VellumValuePrimitive,
|
80
|
-
]
|
30
|
+
EntityInputsInterface = Dict[str, Any]
|
81
31
|
|
82
32
|
|
83
33
|
class MergeBehavior(Enum):
|
@@ -1,6 +1,9 @@
|
|
1
1
|
import dataclasses
|
2
2
|
import inspect
|
3
|
-
from typing import Any, Callable, Union, get_args, get_origin
|
3
|
+
from typing import Any, Callable, Optional, Union, get_args, get_origin
|
4
|
+
|
5
|
+
from pydantic import BaseModel
|
6
|
+
from pydantic_core import PydanticUndefined
|
4
7
|
|
5
8
|
from vellum.client.types.function_definition import FunctionDefinition
|
6
9
|
|
@@ -16,7 +19,10 @@ type_map = {
|
|
16
19
|
}
|
17
20
|
|
18
21
|
|
19
|
-
def _compile_annotation(annotation: Any, defs: dict[str, Any]) -> dict:
|
22
|
+
def _compile_annotation(annotation: Optional[Any], defs: dict[str, Any]) -> dict:
|
23
|
+
if annotation is None:
|
24
|
+
return {"type": "null"}
|
25
|
+
|
20
26
|
if get_origin(annotation) is Union:
|
21
27
|
return {"anyOf": [_compile_annotation(a, defs) for a in get_args(annotation)]}
|
22
28
|
|
@@ -37,13 +43,44 @@ def _compile_annotation(annotation: Any, defs: dict[str, Any]) -> dict:
|
|
37
43
|
if field.default is dataclasses.MISSING:
|
38
44
|
required.append(field.name)
|
39
45
|
else:
|
40
|
-
properties[field.name]["default"] = field.default
|
46
|
+
properties[field.name]["default"] = _compile_default_value(field.default)
|
47
|
+
defs[annotation.__name__] = {"type": "object", "properties": properties, "required": required}
|
48
|
+
return {"$ref": f"#/$defs/{annotation.__name__}"}
|
49
|
+
|
50
|
+
if issubclass(annotation, BaseModel):
|
51
|
+
if annotation.__name__ not in defs:
|
52
|
+
properties = {}
|
53
|
+
required = []
|
54
|
+
for field_name, field in annotation.model_fields.items():
|
55
|
+
# Mypy is incorrect here, the `annotation` attribute is defined on `FieldInfo`
|
56
|
+
field_annotation = field.annotation # type: ignore[attr-defined]
|
57
|
+
properties[field_name] = _compile_annotation(field_annotation, defs)
|
58
|
+
if field.default is PydanticUndefined:
|
59
|
+
required.append(field_name)
|
60
|
+
else:
|
61
|
+
properties[field_name]["default"] = _compile_default_value(field.default)
|
41
62
|
defs[annotation.__name__] = {"type": "object", "properties": properties, "required": required}
|
63
|
+
|
42
64
|
return {"$ref": f"#/$defs/{annotation.__name__}"}
|
43
65
|
|
44
66
|
return {"type": type_map[annotation]}
|
45
67
|
|
46
68
|
|
69
|
+
def _compile_default_value(default: Any) -> Any:
|
70
|
+
if dataclasses.is_dataclass(default):
|
71
|
+
return {
|
72
|
+
field.name: _compile_default_value(getattr(default, field.name)) for field in dataclasses.fields(default)
|
73
|
+
}
|
74
|
+
|
75
|
+
if isinstance(default, BaseModel):
|
76
|
+
return {
|
77
|
+
field_name: _compile_default_value(getattr(default, field_name))
|
78
|
+
for field_name in default.model_fields.keys()
|
79
|
+
}
|
80
|
+
|
81
|
+
return default
|
82
|
+
|
83
|
+
|
47
84
|
def compile_function_definition(function: Callable) -> FunctionDefinition:
|
48
85
|
"""
|
49
86
|
Converts a Python function into our Vellum-native FunctionDefinition type.
|
@@ -62,7 +99,7 @@ def compile_function_definition(function: Callable) -> FunctionDefinition:
|
|
62
99
|
if param.default is inspect.Parameter.empty:
|
63
100
|
required.append(param.name)
|
64
101
|
else:
|
65
|
-
properties[param.name]["default"] = param.default
|
102
|
+
properties[param.name]["default"] = _compile_default_value(param.default)
|
66
103
|
|
67
104
|
parameters = {"type": "object", "properties": properties, "required": required}
|
68
105
|
if defs:
|