vellum-ai 0.12.3__py3-none-any.whl → 0.12.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/resources/workflows/client.py +32 -0
- vellum/client/types/chat_message_prompt_block.py +1 -1
- vellum/client/types/function_definition.py +26 -7
- vellum/client/types/jinja_prompt_block.py +1 -1
- vellum/client/types/plain_text_prompt_block.py +1 -1
- vellum/client/types/rich_text_prompt_block.py +1 -1
- vellum/client/types/variable_prompt_block.py +1 -1
- vellum/plugins/vellum_mypy.py +20 -23
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +21 -8
- vellum/workflows/nodes/displayable/inline_prompt_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +64 -0
- vellum/workflows/sandbox.py +51 -0
- vellum/workflows/tests/__init__.py +0 -0
- vellum/workflows/tests/test_sandbox.py +62 -0
- vellum/workflows/types/core.py +2 -52
- vellum/workflows/utils/functions.py +41 -4
- vellum/workflows/utils/tests/test_functions.py +93 -0
- {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/METADATA +1 -1
- {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/RECORD +33 -28
- vellum_cli/__init__.py +14 -0
- vellum_cli/pull.py +16 -2
- vellum_cli/tests/test_pull.py +45 -0
- vellum_ee/workflows/display/nodes/base_node_vellum_display.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +1 -3
- vellum_ee/workflows/display/nodes/vellum/conditional_node.py +18 -18
- vellum_ee/workflows/display/nodes/vellum/search_node.py +19 -15
- vellum_ee/workflows/display/nodes/vellum/templating_node.py +1 -1
- vellum_ee/workflows/display/nodes/vellum/utils.py +4 -4
- vellum_ee/workflows/display/utils/vellum.py +1 -2
- {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/LICENSE +0 -0
- {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/WHEEL +0 -0
- {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
|
|
18
18
|
headers: typing.Dict[str, str] = {
|
19
19
|
"X-Fern-Language": "Python",
|
20
20
|
"X-Fern-SDK-Name": "vellum-ai",
|
21
|
-
"X-Fern-SDK-Version": "0.12.
|
21
|
+
"X-Fern-SDK-Version": "0.12.5",
|
22
22
|
}
|
23
23
|
headers["X_API_KEY"] = self.api_key
|
24
24
|
return headers
|
@@ -27,7 +27,11 @@ class WorkflowsClient:
|
|
27
27
|
self,
|
28
28
|
id: str,
|
29
29
|
*,
|
30
|
+
exclude_code: typing.Optional[bool] = None,
|
30
31
|
format: typing.Optional[WorkflowsPullRequestFormat] = None,
|
32
|
+
include_json: typing.Optional[bool] = None,
|
33
|
+
include_sandbox: typing.Optional[bool] = None,
|
34
|
+
strict: typing.Optional[bool] = None,
|
31
35
|
request_options: typing.Optional[RequestOptions] = None,
|
32
36
|
) -> typing.Iterator[bytes]:
|
33
37
|
"""
|
@@ -38,8 +42,16 @@ class WorkflowsClient:
|
|
38
42
|
id : str
|
39
43
|
The ID of the Workflow to pull from
|
40
44
|
|
45
|
+
exclude_code : typing.Optional[bool]
|
46
|
+
|
41
47
|
format : typing.Optional[WorkflowsPullRequestFormat]
|
42
48
|
|
49
|
+
include_json : typing.Optional[bool]
|
50
|
+
|
51
|
+
include_sandbox : typing.Optional[bool]
|
52
|
+
|
53
|
+
strict : typing.Optional[bool]
|
54
|
+
|
43
55
|
request_options : typing.Optional[RequestOptions]
|
44
56
|
Request-specific configuration.
|
45
57
|
|
@@ -53,7 +65,11 @@ class WorkflowsClient:
|
|
53
65
|
base_url=self._client_wrapper.get_environment().default,
|
54
66
|
method="GET",
|
55
67
|
params={
|
68
|
+
"exclude_code": exclude_code,
|
56
69
|
"format": format,
|
70
|
+
"include_json": include_json,
|
71
|
+
"include_sandbox": include_sandbox,
|
72
|
+
"strict": strict,
|
57
73
|
},
|
58
74
|
request_options=request_options,
|
59
75
|
) as _response:
|
@@ -164,7 +180,11 @@ class AsyncWorkflowsClient:
|
|
164
180
|
self,
|
165
181
|
id: str,
|
166
182
|
*,
|
183
|
+
exclude_code: typing.Optional[bool] = None,
|
167
184
|
format: typing.Optional[WorkflowsPullRequestFormat] = None,
|
185
|
+
include_json: typing.Optional[bool] = None,
|
186
|
+
include_sandbox: typing.Optional[bool] = None,
|
187
|
+
strict: typing.Optional[bool] = None,
|
168
188
|
request_options: typing.Optional[RequestOptions] = None,
|
169
189
|
) -> typing.AsyncIterator[bytes]:
|
170
190
|
"""
|
@@ -175,8 +195,16 @@ class AsyncWorkflowsClient:
|
|
175
195
|
id : str
|
176
196
|
The ID of the Workflow to pull from
|
177
197
|
|
198
|
+
exclude_code : typing.Optional[bool]
|
199
|
+
|
178
200
|
format : typing.Optional[WorkflowsPullRequestFormat]
|
179
201
|
|
202
|
+
include_json : typing.Optional[bool]
|
203
|
+
|
204
|
+
include_sandbox : typing.Optional[bool]
|
205
|
+
|
206
|
+
strict : typing.Optional[bool]
|
207
|
+
|
180
208
|
request_options : typing.Optional[RequestOptions]
|
181
209
|
Request-specific configuration.
|
182
210
|
|
@@ -190,7 +218,11 @@ class AsyncWorkflowsClient:
|
|
190
218
|
base_url=self._client_wrapper.get_environment().default,
|
191
219
|
method="GET",
|
192
220
|
params={
|
221
|
+
"exclude_code": exclude_code,
|
193
222
|
"format": format,
|
223
|
+
"include_json": include_json,
|
224
|
+
"include_sandbox": include_sandbox,
|
225
|
+
"strict": strict,
|
194
226
|
},
|
195
227
|
request_options=request_options,
|
196
228
|
) as _response:
|
@@ -16,9 +16,9 @@ class ChatMessagePromptBlock(UniversalBaseModel):
|
|
16
16
|
A block that represents a chat message in a prompt template.
|
17
17
|
"""
|
18
18
|
|
19
|
+
block_type: typing.Literal["CHAT_MESSAGE"] = "CHAT_MESSAGE"
|
19
20
|
state: typing.Optional[PromptBlockState] = None
|
20
21
|
cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
|
21
|
-
block_type: typing.Literal["CHAT_MESSAGE"] = "CHAT_MESSAGE"
|
22
22
|
chat_role: ChatMessageRole
|
23
23
|
chat_source: typing.Optional[str] = None
|
24
24
|
chat_message_unterminated: typing.Optional[bool] = None
|
@@ -4,22 +4,41 @@ from ..core.pydantic_utilities import UniversalBaseModel
|
|
4
4
|
import typing
|
5
5
|
from .prompt_block_state import PromptBlockState
|
6
6
|
from .ephemeral_prompt_cache_config import EphemeralPromptCacheConfig
|
7
|
-
from ..core.pydantic_utilities import IS_PYDANTIC_V2
|
8
7
|
import pydantic
|
8
|
+
from ..core.pydantic_utilities import IS_PYDANTIC_V2
|
9
9
|
|
10
10
|
|
11
11
|
class FunctionDefinition(UniversalBaseModel):
|
12
12
|
"""
|
13
|
-
|
13
|
+
The definition of a Function (aka "Tool Call") that a Prompt/Model has access to.
|
14
14
|
"""
|
15
15
|
|
16
16
|
state: typing.Optional[PromptBlockState] = None
|
17
17
|
cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
|
18
|
-
name: typing.Optional[str] = None
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
18
|
+
name: typing.Optional[str] = pydantic.Field(default=None)
|
19
|
+
"""
|
20
|
+
The name identifying the function.
|
21
|
+
"""
|
22
|
+
|
23
|
+
description: typing.Optional[str] = pydantic.Field(default=None)
|
24
|
+
"""
|
25
|
+
A description to help guide the model when to invoke this function.
|
26
|
+
"""
|
27
|
+
|
28
|
+
parameters: typing.Optional[typing.Dict[str, typing.Optional[typing.Any]]] = pydantic.Field(default=None)
|
29
|
+
"""
|
30
|
+
An OpenAPI specification of parameters that are supported by this function.
|
31
|
+
"""
|
32
|
+
|
33
|
+
forced: typing.Optional[bool] = pydantic.Field(default=None)
|
34
|
+
"""
|
35
|
+
Set this option to true to force the model to return a function call of this function.
|
36
|
+
"""
|
37
|
+
|
38
|
+
strict: typing.Optional[bool] = pydantic.Field(default=None)
|
39
|
+
"""
|
40
|
+
Set this option to use strict schema decoding when available.
|
41
|
+
"""
|
23
42
|
|
24
43
|
if IS_PYDANTIC_V2:
|
25
44
|
model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
|
@@ -13,9 +13,9 @@ class JinjaPromptBlock(UniversalBaseModel):
|
|
13
13
|
A block of Jinja template code that is used to generate a prompt
|
14
14
|
"""
|
15
15
|
|
16
|
+
block_type: typing.Literal["JINJA"] = "JINJA"
|
16
17
|
state: typing.Optional[PromptBlockState] = None
|
17
18
|
cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
|
18
|
-
block_type: typing.Literal["JINJA"] = "JINJA"
|
19
19
|
template: str
|
20
20
|
|
21
21
|
if IS_PYDANTIC_V2:
|
@@ -13,9 +13,9 @@ class PlainTextPromptBlock(UniversalBaseModel):
|
|
13
13
|
A block that holds a plain text string value.
|
14
14
|
"""
|
15
15
|
|
16
|
+
block_type: typing.Literal["PLAIN_TEXT"] = "PLAIN_TEXT"
|
16
17
|
state: typing.Optional[PromptBlockState] = None
|
17
18
|
cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
|
18
|
-
block_type: typing.Literal["PLAIN_TEXT"] = "PLAIN_TEXT"
|
19
19
|
text: str
|
20
20
|
|
21
21
|
if IS_PYDANTIC_V2:
|
@@ -14,9 +14,9 @@ class RichTextPromptBlock(UniversalBaseModel):
|
|
14
14
|
A block that includes a combination of plain text and variable blocks.
|
15
15
|
"""
|
16
16
|
|
17
|
+
block_type: typing.Literal["RICH_TEXT"] = "RICH_TEXT"
|
17
18
|
state: typing.Optional[PromptBlockState] = None
|
18
19
|
cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
|
19
|
-
block_type: typing.Literal["RICH_TEXT"] = "RICH_TEXT"
|
20
20
|
blocks: typing.List[RichTextChildBlock]
|
21
21
|
|
22
22
|
if IS_PYDANTIC_V2:
|
@@ -13,9 +13,9 @@ class VariablePromptBlock(UniversalBaseModel):
|
|
13
13
|
A block that represents a variable in a prompt template.
|
14
14
|
"""
|
15
15
|
|
16
|
+
block_type: typing.Literal["VARIABLE"] = "VARIABLE"
|
16
17
|
state: typing.Optional[PromptBlockState] = None
|
17
18
|
cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
|
18
|
-
block_type: typing.Literal["VARIABLE"] = "VARIABLE"
|
19
19
|
input_variable: str
|
20
20
|
|
21
21
|
if IS_PYDANTIC_V2:
|
vellum/plugins/vellum_mypy.py
CHANGED
@@ -172,41 +172,39 @@ class VellumMypyPlugin(Plugin):
|
|
172
172
|
|
173
173
|
def _dynamic_output_node_class_hook(self, ctx: ClassDefContext, attribute_name: str) -> None:
|
174
174
|
"""
|
175
|
-
We use this hook to properly annotate the Outputs class for Templating Node
|
176
|
-
|
175
|
+
We use this hook to properly annotate the Outputs class for Templating Node and
|
176
|
+
Code Execution Node using the resolved type
|
177
|
+
of the TemplatingNode's and CodeExecutionNode's class _OutputType generic.
|
177
178
|
"""
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
if not templating_node_bases:
|
179
|
+
node_info = ctx.cls.info
|
180
|
+
node_bases = ctx.cls.info.bases
|
181
|
+
if not node_bases:
|
182
182
|
return
|
183
|
-
if not isinstance(
|
183
|
+
if not isinstance(node_bases[0], Instance):
|
184
184
|
return
|
185
185
|
|
186
|
-
|
187
|
-
|
188
|
-
if not _is_subclass(base_templating_node, "vellum.workflows.nodes.core.templating_node.node.TemplatingNode"):
|
189
|
-
return
|
186
|
+
base_args = node_bases[0].args
|
187
|
+
base_node = node_bases[0].type
|
190
188
|
|
191
|
-
if len(
|
189
|
+
if len(base_args) != 2:
|
192
190
|
return
|
193
191
|
|
194
|
-
|
195
|
-
if isinstance(
|
196
|
-
|
192
|
+
base_node_resolved_type = base_args[1]
|
193
|
+
if isinstance(base_node_resolved_type, AnyType):
|
194
|
+
base_node_resolved_type = ctx.api.named_type("builtins.str")
|
197
195
|
|
198
|
-
|
199
|
-
if not
|
196
|
+
base_node_outputs = base_node.names.get("Outputs")
|
197
|
+
if not base_node_outputs:
|
200
198
|
return
|
201
199
|
|
202
|
-
|
203
|
-
if not
|
204
|
-
|
205
|
-
new_outputs_sym =
|
200
|
+
current_node_outputs = node_info.names.get("Outputs")
|
201
|
+
if not current_node_outputs:
|
202
|
+
node_info.names["Outputs"] = base_node_outputs.copy()
|
203
|
+
new_outputs_sym = node_info.names["Outputs"].node
|
206
204
|
if isinstance(new_outputs_sym, TypeInfo):
|
207
205
|
result_sym = new_outputs_sym.names[attribute_name].node
|
208
206
|
if isinstance(result_sym, Var):
|
209
|
-
result_sym.type =
|
207
|
+
result_sym.type = base_node_resolved_type
|
210
208
|
|
211
209
|
def _base_node_class_hook(self, ctx: ClassDefContext) -> None:
|
212
210
|
"""
|
@@ -233,7 +231,6 @@ class VellumMypyPlugin(Plugin):
|
|
233
231
|
type_ = sym.node.type
|
234
232
|
if not type_:
|
235
233
|
continue
|
236
|
-
|
237
234
|
sym.node.type = self._get_resolvable_type(
|
238
235
|
lambda fullname, types: ctx.api.named_type(fullname, types), type_
|
239
236
|
)
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import json
|
1
2
|
from uuid import uuid4
|
2
3
|
from typing import ClassVar, Generic, Iterator, List, Optional, Tuple, cast
|
3
4
|
|
@@ -18,6 +19,7 @@ from vellum.client import RequestOptions
|
|
18
19
|
from vellum.workflows.constants import OMIT
|
19
20
|
from vellum.workflows.context import get_parent_context
|
20
21
|
from vellum.workflows.errors import WorkflowErrorCode
|
22
|
+
from vellum.workflows.events.types import default_serializer
|
21
23
|
from vellum.workflows.exceptions import NodeException
|
22
24
|
from vellum.workflows.nodes.displayable.bases.base_prompt_node import BasePromptNode
|
23
25
|
from vellum.workflows.nodes.displayable.bases.inline_prompt_node.constants import DEFAULT_PROMPT_PARAMETERS
|
@@ -108,19 +110,30 @@ class BaseInlinePromptNode(BasePromptNode, Generic[StateType]):
|
|
108
110
|
value=cast(List[ChatMessage], input_value),
|
109
111
|
)
|
110
112
|
)
|
111
|
-
|
112
|
-
|
113
|
-
|
113
|
+
else:
|
114
|
+
try:
|
115
|
+
input_value = default_serializer(input_value)
|
116
|
+
except json.JSONDecodeError as e:
|
117
|
+
raise NodeException(
|
118
|
+
message=f"Failed to serialize input '{input_name}' of type '{input_value.__class__}': {e}",
|
119
|
+
code=WorkflowErrorCode.INVALID_INPUTS,
|
120
|
+
)
|
121
|
+
|
122
|
+
input_variables.append(
|
123
|
+
VellumVariable(
|
124
|
+
# TODO: Determine whether or not we actually need an id here and if we do,
|
125
|
+
# figure out how to maintain stable id references.
|
126
|
+
# https://app.shortcut.com/vellum/story/4080
|
127
|
+
id=str(uuid4()),
|
128
|
+
key=input_name,
|
129
|
+
type="JSON",
|
130
|
+
)
|
131
|
+
)
|
114
132
|
input_values.append(
|
115
133
|
PromptRequestJsonInput(
|
116
134
|
key=input_name,
|
117
135
|
value=input_value,
|
118
136
|
)
|
119
137
|
)
|
120
|
-
else:
|
121
|
-
raise NodeException(
|
122
|
-
message=f"Unrecognized input type for input '{input_name}': {input_value.__class__}",
|
123
|
-
code=WorkflowErrorCode.INVALID_INPUTS,
|
124
|
-
)
|
125
138
|
|
126
139
|
return input_variables, input_values
|
File without changes
|
@@ -0,0 +1,64 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from uuid import uuid4
|
3
|
+
from typing import Any, Iterator, List
|
4
|
+
|
5
|
+
from vellum.client.core.pydantic_utilities import UniversalBaseModel
|
6
|
+
from vellum.client.types.execute_prompt_event import ExecutePromptEvent
|
7
|
+
from vellum.client.types.fulfilled_execute_prompt_event import FulfilledExecutePromptEvent
|
8
|
+
from vellum.client.types.initiated_execute_prompt_event import InitiatedExecutePromptEvent
|
9
|
+
from vellum.client.types.prompt_output import PromptOutput
|
10
|
+
from vellum.client.types.prompt_request_json_input import PromptRequestJsonInput
|
11
|
+
from vellum.client.types.string_vellum_value import StringVellumValue
|
12
|
+
from vellum.workflows.nodes.displayable.inline_prompt_node.node import InlinePromptNode
|
13
|
+
|
14
|
+
|
15
|
+
def test_inline_prompt_node__json_inputs(vellum_adhoc_prompt_client):
|
16
|
+
# GIVEN a prompt node with various inputs
|
17
|
+
@dataclass
|
18
|
+
class MyDataClass:
|
19
|
+
hello: str
|
20
|
+
|
21
|
+
class MyPydantic(UniversalBaseModel):
|
22
|
+
example: str
|
23
|
+
|
24
|
+
class MyNode(InlinePromptNode):
|
25
|
+
ml_model = "gpt-4o"
|
26
|
+
blocks = []
|
27
|
+
prompt_inputs = {
|
28
|
+
"a_dict": {"foo": "bar"},
|
29
|
+
"a_list": [1, 2, 3],
|
30
|
+
"a_dataclass": MyDataClass(hello="world"),
|
31
|
+
"a_pydantic": MyPydantic(example="example"),
|
32
|
+
}
|
33
|
+
|
34
|
+
# AND a known response from invoking an inline prompt
|
35
|
+
expected_outputs: List[PromptOutput] = [
|
36
|
+
StringVellumValue(value="Test"),
|
37
|
+
]
|
38
|
+
|
39
|
+
def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
|
40
|
+
execution_id = str(uuid4())
|
41
|
+
events: List[ExecutePromptEvent] = [
|
42
|
+
InitiatedExecutePromptEvent(execution_id=execution_id),
|
43
|
+
FulfilledExecutePromptEvent(
|
44
|
+
execution_id=execution_id,
|
45
|
+
outputs=expected_outputs,
|
46
|
+
),
|
47
|
+
]
|
48
|
+
yield from events
|
49
|
+
|
50
|
+
vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.side_effect = generate_prompt_events
|
51
|
+
|
52
|
+
# WHEN the node is run
|
53
|
+
list(MyNode().run())
|
54
|
+
|
55
|
+
# THEN the prompt is executed with the correct inputs
|
56
|
+
mock_api = vellum_adhoc_prompt_client.adhoc_execute_prompt_stream
|
57
|
+
assert mock_api.call_count == 1
|
58
|
+
assert mock_api.call_args.kwargs["input_values"] == [
|
59
|
+
PromptRequestJsonInput(key="a_dict", type="JSON", value={"foo": "bar"}),
|
60
|
+
PromptRequestJsonInput(key="a_list", type="JSON", value=[1, 2, 3]),
|
61
|
+
PromptRequestJsonInput(key="a_dataclass", type="JSON", value={"hello": "world"}),
|
62
|
+
PromptRequestJsonInput(key="a_pydantic", type="JSON", value={"example": "example"}),
|
63
|
+
]
|
64
|
+
assert len(mock_api.call_args.kwargs["input_variables"]) == 4
|
@@ -0,0 +1,51 @@
|
|
1
|
+
from typing import Generic, Sequence, Type
|
2
|
+
|
3
|
+
import dotenv
|
4
|
+
|
5
|
+
from vellum.workflows.events.workflow import WorkflowEventStream
|
6
|
+
from vellum.workflows.inputs.base import BaseInputs
|
7
|
+
from vellum.workflows.logging import load_logger
|
8
|
+
from vellum.workflows.types.generics import WorkflowType
|
9
|
+
from vellum.workflows.workflows.event_filters import root_workflow_event_filter
|
10
|
+
|
11
|
+
|
12
|
+
class SandboxRunner(Generic[WorkflowType]):
|
13
|
+
def __init__(self, workflow: Type[WorkflowType], inputs: Sequence[BaseInputs]):
|
14
|
+
if not inputs:
|
15
|
+
raise ValueError("Inputs are required to have at least one defined inputs")
|
16
|
+
|
17
|
+
self._workflow = workflow
|
18
|
+
self._inputs = inputs
|
19
|
+
|
20
|
+
dotenv.load_dotenv()
|
21
|
+
self._logger = load_logger()
|
22
|
+
|
23
|
+
def run(self, index: int = 0):
|
24
|
+
if index < 0:
|
25
|
+
self._logger.warning("Index is less than 0, running first input")
|
26
|
+
index = 0
|
27
|
+
elif index >= len(self._inputs):
|
28
|
+
self._logger.warning("Index is greater than the number of provided inputs, running last input")
|
29
|
+
index = len(self._inputs) - 1
|
30
|
+
|
31
|
+
selected_inputs = self._inputs[index]
|
32
|
+
|
33
|
+
workflow = self._workflow()
|
34
|
+
events = workflow.stream(
|
35
|
+
inputs=selected_inputs,
|
36
|
+
event_filter=root_workflow_event_filter,
|
37
|
+
)
|
38
|
+
|
39
|
+
self._process_events(events)
|
40
|
+
|
41
|
+
def _process_events(self, events: WorkflowEventStream):
|
42
|
+
for event in events:
|
43
|
+
if event.name == "workflow.execution.fulfilled":
|
44
|
+
self._logger.info("Workflow fulfilled!")
|
45
|
+
for output_reference, value in event.outputs:
|
46
|
+
self._logger.info("----------------------------------")
|
47
|
+
self._logger.info(f"{output_reference.name}: {value}")
|
48
|
+
elif event.name == "node.execution.initiated":
|
49
|
+
self._logger.info(f"Just started Node: {event.node_definition.__name__}")
|
50
|
+
elif event.name == "node.execution.fulfilled":
|
51
|
+
self._logger.info(f"Just finished Node: {event.node_definition.__name__}")
|
File without changes
|
@@ -0,0 +1,62 @@
|
|
1
|
+
import pytest
|
2
|
+
from typing import List
|
3
|
+
|
4
|
+
from vellum.workflows.inputs.base import BaseInputs
|
5
|
+
from vellum.workflows.nodes.bases.base import BaseNode
|
6
|
+
from vellum.workflows.sandbox import SandboxRunner
|
7
|
+
from vellum.workflows.state.base import BaseState
|
8
|
+
from vellum.workflows.workflows.base import BaseWorkflow
|
9
|
+
|
10
|
+
|
11
|
+
@pytest.fixture
|
12
|
+
def mock_logger(mocker):
|
13
|
+
return mocker.patch("vellum.workflows.sandbox.load_logger")
|
14
|
+
|
15
|
+
|
16
|
+
@pytest.mark.parametrize(
|
17
|
+
["run_kwargs", "expected_last_log"],
|
18
|
+
[
|
19
|
+
({}, "final_results: first"),
|
20
|
+
({"index": 1}, "final_results: second"),
|
21
|
+
({"index": -4}, "final_results: first"),
|
22
|
+
({"index": 100}, "final_results: second"),
|
23
|
+
],
|
24
|
+
ids=["default", "specific", "negative", "out_of_bounds"],
|
25
|
+
)
|
26
|
+
def test_sandbox_runner__happy_path(mock_logger, run_kwargs, expected_last_log):
|
27
|
+
# GIVEN we capture the logs to stdout
|
28
|
+
logs = []
|
29
|
+
mock_logger.return_value.info.side_effect = lambda msg: logs.append(msg)
|
30
|
+
|
31
|
+
# AND an example workflow
|
32
|
+
class Inputs(BaseInputs):
|
33
|
+
foo: str
|
34
|
+
|
35
|
+
class StartNode(BaseNode):
|
36
|
+
class Outputs(BaseNode.Outputs):
|
37
|
+
bar = Inputs.foo
|
38
|
+
|
39
|
+
class Workflow(BaseWorkflow[Inputs, BaseState]):
|
40
|
+
graph = StartNode
|
41
|
+
|
42
|
+
class Outputs(BaseWorkflow.Outputs):
|
43
|
+
final_results = StartNode.Outputs.bar
|
44
|
+
|
45
|
+
# AND a dataset for this workflow
|
46
|
+
inputs: List[Inputs] = [
|
47
|
+
Inputs(foo="first"),
|
48
|
+
Inputs(foo="second"),
|
49
|
+
]
|
50
|
+
|
51
|
+
# WHEN we run the sandbox
|
52
|
+
runner = SandboxRunner(workflow=Workflow, inputs=inputs)
|
53
|
+
runner.run(**run_kwargs)
|
54
|
+
|
55
|
+
# THEN we see the logs
|
56
|
+
assert logs == [
|
57
|
+
"Just started Node: StartNode",
|
58
|
+
"Just finished Node: StartNode",
|
59
|
+
"Workflow fulfilled!",
|
60
|
+
"----------------------------------",
|
61
|
+
expected_last_log,
|
62
|
+
]
|
vellum/workflows/types/core.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
from enum import Enum
|
2
2
|
from typing import ( # type: ignore[attr-defined]
|
3
|
+
Any,
|
3
4
|
Dict,
|
4
5
|
List,
|
5
6
|
Union,
|
@@ -8,22 +9,6 @@ from typing import ( # type: ignore[attr-defined]
|
|
8
9
|
_UnionGenericAlias,
|
9
10
|
)
|
10
11
|
|
11
|
-
from vellum import (
|
12
|
-
ChatMessage,
|
13
|
-
FunctionCall,
|
14
|
-
FunctionCallRequest,
|
15
|
-
SearchResult,
|
16
|
-
SearchResultRequest,
|
17
|
-
VellumAudio,
|
18
|
-
VellumAudioRequest,
|
19
|
-
VellumError,
|
20
|
-
VellumErrorRequest,
|
21
|
-
VellumImage,
|
22
|
-
VellumImageRequest,
|
23
|
-
VellumValue,
|
24
|
-
VellumValueRequest,
|
25
|
-
)
|
26
|
-
|
27
12
|
JsonArray = List["Json"]
|
28
13
|
JsonObject = Dict[str, "Json"]
|
29
14
|
Json = Union[None, bool, int, float, str, JsonArray, JsonObject]
|
@@ -42,42 +27,7 @@ class VellumSecret:
|
|
42
27
|
self.name = name
|
43
28
|
|
44
29
|
|
45
|
-
|
46
|
-
# String inputs
|
47
|
-
str,
|
48
|
-
# Chat history inputs
|
49
|
-
List[ChatMessage],
|
50
|
-
List[ChatMessage],
|
51
|
-
# Search results inputs
|
52
|
-
List[SearchResultRequest],
|
53
|
-
List[SearchResult],
|
54
|
-
# JSON inputs
|
55
|
-
Json,
|
56
|
-
# Number inputs
|
57
|
-
float,
|
58
|
-
# Function Call Inputs
|
59
|
-
FunctionCall,
|
60
|
-
FunctionCallRequest,
|
61
|
-
# Error Inputs
|
62
|
-
VellumError,
|
63
|
-
VellumErrorRequest,
|
64
|
-
# Array Inputs
|
65
|
-
List[VellumValueRequest],
|
66
|
-
List[VellumValue],
|
67
|
-
# Image Inputs
|
68
|
-
VellumImage,
|
69
|
-
VellumImageRequest,
|
70
|
-
# Audio Inputs
|
71
|
-
VellumAudio,
|
72
|
-
VellumAudioRequest,
|
73
|
-
# Vellum Secrets
|
74
|
-
VellumSecret,
|
75
|
-
]
|
76
|
-
|
77
|
-
EntityInputsInterface = Dict[
|
78
|
-
str,
|
79
|
-
VellumValuePrimitive,
|
80
|
-
]
|
30
|
+
EntityInputsInterface = Dict[str, Any]
|
81
31
|
|
82
32
|
|
83
33
|
class MergeBehavior(Enum):
|
@@ -1,6 +1,9 @@
|
|
1
1
|
import dataclasses
|
2
2
|
import inspect
|
3
|
-
from typing import Any, Callable, Union, get_args, get_origin
|
3
|
+
from typing import Any, Callable, Optional, Union, get_args, get_origin
|
4
|
+
|
5
|
+
from pydantic import BaseModel
|
6
|
+
from pydantic_core import PydanticUndefined
|
4
7
|
|
5
8
|
from vellum.client.types.function_definition import FunctionDefinition
|
6
9
|
|
@@ -16,7 +19,10 @@ type_map = {
|
|
16
19
|
}
|
17
20
|
|
18
21
|
|
19
|
-
def _compile_annotation(annotation: Any, defs: dict[str, Any]) -> dict:
|
22
|
+
def _compile_annotation(annotation: Optional[Any], defs: dict[str, Any]) -> dict:
|
23
|
+
if annotation is None:
|
24
|
+
return {"type": "null"}
|
25
|
+
|
20
26
|
if get_origin(annotation) is Union:
|
21
27
|
return {"anyOf": [_compile_annotation(a, defs) for a in get_args(annotation)]}
|
22
28
|
|
@@ -37,13 +43,44 @@ def _compile_annotation(annotation: Any, defs: dict[str, Any]) -> dict:
|
|
37
43
|
if field.default is dataclasses.MISSING:
|
38
44
|
required.append(field.name)
|
39
45
|
else:
|
40
|
-
properties[field.name]["default"] = field.default
|
46
|
+
properties[field.name]["default"] = _compile_default_value(field.default)
|
47
|
+
defs[annotation.__name__] = {"type": "object", "properties": properties, "required": required}
|
48
|
+
return {"$ref": f"#/$defs/{annotation.__name__}"}
|
49
|
+
|
50
|
+
if issubclass(annotation, BaseModel):
|
51
|
+
if annotation.__name__ not in defs:
|
52
|
+
properties = {}
|
53
|
+
required = []
|
54
|
+
for field_name, field in annotation.model_fields.items():
|
55
|
+
# Mypy is incorrect here, the `annotation` attribute is defined on `FieldInfo`
|
56
|
+
field_annotation = field.annotation # type: ignore[attr-defined]
|
57
|
+
properties[field_name] = _compile_annotation(field_annotation, defs)
|
58
|
+
if field.default is PydanticUndefined:
|
59
|
+
required.append(field_name)
|
60
|
+
else:
|
61
|
+
properties[field_name]["default"] = _compile_default_value(field.default)
|
41
62
|
defs[annotation.__name__] = {"type": "object", "properties": properties, "required": required}
|
63
|
+
|
42
64
|
return {"$ref": f"#/$defs/{annotation.__name__}"}
|
43
65
|
|
44
66
|
return {"type": type_map[annotation]}
|
45
67
|
|
46
68
|
|
69
|
+
def _compile_default_value(default: Any) -> Any:
|
70
|
+
if dataclasses.is_dataclass(default):
|
71
|
+
return {
|
72
|
+
field.name: _compile_default_value(getattr(default, field.name)) for field in dataclasses.fields(default)
|
73
|
+
}
|
74
|
+
|
75
|
+
if isinstance(default, BaseModel):
|
76
|
+
return {
|
77
|
+
field_name: _compile_default_value(getattr(default, field_name))
|
78
|
+
for field_name in default.model_fields.keys()
|
79
|
+
}
|
80
|
+
|
81
|
+
return default
|
82
|
+
|
83
|
+
|
47
84
|
def compile_function_definition(function: Callable) -> FunctionDefinition:
|
48
85
|
"""
|
49
86
|
Converts a Python function into our Vellum-native FunctionDefinition type.
|
@@ -62,7 +99,7 @@ def compile_function_definition(function: Callable) -> FunctionDefinition:
|
|
62
99
|
if param.default is inspect.Parameter.empty:
|
63
100
|
required.append(param.name)
|
64
101
|
else:
|
65
|
-
properties[param.name]["default"] = param.default
|
102
|
+
properties[param.name]["default"] = _compile_default_value(param.default)
|
66
103
|
|
67
104
|
parameters = {"type": "object", "properties": properties, "required": required}
|
68
105
|
if defs:
|