vellum-ai 0.12.3__py3-none-any.whl → 0.12.5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (33)
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/client/resources/workflows/client.py +32 -0
  3. vellum/client/types/chat_message_prompt_block.py +1 -1
  4. vellum/client/types/function_definition.py +26 -7
  5. vellum/client/types/jinja_prompt_block.py +1 -1
  6. vellum/client/types/plain_text_prompt_block.py +1 -1
  7. vellum/client/types/rich_text_prompt_block.py +1 -1
  8. vellum/client/types/variable_prompt_block.py +1 -1
  9. vellum/plugins/vellum_mypy.py +20 -23
  10. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +21 -8
  11. vellum/workflows/nodes/displayable/inline_prompt_node/tests/__init__.py +0 -0
  12. vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +64 -0
  13. vellum/workflows/sandbox.py +51 -0
  14. vellum/workflows/tests/__init__.py +0 -0
  15. vellum/workflows/tests/test_sandbox.py +62 -0
  16. vellum/workflows/types/core.py +2 -52
  17. vellum/workflows/utils/functions.py +41 -4
  18. vellum/workflows/utils/tests/test_functions.py +93 -0
  19. {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/METADATA +1 -1
  20. {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/RECORD +33 -28
  21. vellum_cli/__init__.py +14 -0
  22. vellum_cli/pull.py +16 -2
  23. vellum_cli/tests/test_pull.py +45 -0
  24. vellum_ee/workflows/display/nodes/base_node_vellum_display.py +2 -2
  25. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +1 -3
  26. vellum_ee/workflows/display/nodes/vellum/conditional_node.py +18 -18
  27. vellum_ee/workflows/display/nodes/vellum/search_node.py +19 -15
  28. vellum_ee/workflows/display/nodes/vellum/templating_node.py +1 -1
  29. vellum_ee/workflows/display/nodes/vellum/utils.py +4 -4
  30. vellum_ee/workflows/display/utils/vellum.py +1 -2
  31. {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/LICENSE +0 -0
  32. {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/WHEEL +0 -0
  33. {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/entry_points.txt +0 -0
vellum/client/core/client_wrapper.py
@@ -18,7 +18,7 @@ class BaseClientWrapper:
         headers: typing.Dict[str, str] = {
             "X-Fern-Language": "Python",
             "X-Fern-SDK-Name": "vellum-ai",
-            "X-Fern-SDK-Version": "0.12.3",
+            "X-Fern-SDK-Version": "0.12.5",
         }
         headers["X_API_KEY"] = self.api_key
         return headers
vellum/client/resources/workflows/client.py
@@ -27,7 +27,11 @@ class WorkflowsClient:
         self,
         id: str,
         *,
+        exclude_code: typing.Optional[bool] = None,
         format: typing.Optional[WorkflowsPullRequestFormat] = None,
+        include_json: typing.Optional[bool] = None,
+        include_sandbox: typing.Optional[bool] = None,
+        strict: typing.Optional[bool] = None,
         request_options: typing.Optional[RequestOptions] = None,
     ) -> typing.Iterator[bytes]:
         """
@@ -38,8 +42,16 @@ class WorkflowsClient:
         id : str
             The ID of the Workflow to pull from

+        exclude_code : typing.Optional[bool]
+
         format : typing.Optional[WorkflowsPullRequestFormat]

+        include_json : typing.Optional[bool]
+
+        include_sandbox : typing.Optional[bool]
+
+        strict : typing.Optional[bool]
+
         request_options : typing.Optional[RequestOptions]
             Request-specific configuration.

@@ -53,7 +65,11 @@ class WorkflowsClient:
             base_url=self._client_wrapper.get_environment().default,
             method="GET",
             params={
+                "exclude_code": exclude_code,
                 "format": format,
+                "include_json": include_json,
+                "include_sandbox": include_sandbox,
+                "strict": strict,
             },
             request_options=request_options,
         ) as _response:
@@ -164,7 +180,11 @@ class AsyncWorkflowsClient:
         self,
         id: str,
         *,
+        exclude_code: typing.Optional[bool] = None,
         format: typing.Optional[WorkflowsPullRequestFormat] = None,
+        include_json: typing.Optional[bool] = None,
+        include_sandbox: typing.Optional[bool] = None,
+        strict: typing.Optional[bool] = None,
         request_options: typing.Optional[RequestOptions] = None,
     ) -> typing.AsyncIterator[bytes]:
         """
@@ -175,8 +195,16 @@ class AsyncWorkflowsClient:
         id : str
             The ID of the Workflow to pull from

+        exclude_code : typing.Optional[bool]
+
         format : typing.Optional[WorkflowsPullRequestFormat]

+        include_json : typing.Optional[bool]
+
+        include_sandbox : typing.Optional[bool]
+
+        strict : typing.Optional[bool]
+
         request_options : typing.Optional[RequestOptions]
             Request-specific configuration.

@@ -190,7 +218,11 @@ class AsyncWorkflowsClient:
             base_url=self._client_wrapper.get_environment().default,
             method="GET",
             params={
+                "exclude_code": exclude_code,
                 "format": format,
+                "include_json": include_json,
+                "include_sandbox": include_sandbox,
+                "strict": strict,
             },
             request_options=request_options,
         ) as _response:
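
The four new pull parameters are plumbed straight through to the query string in both the sync and async clients. A minimal sketch of exercising them directly, assuming a configured client, an illustrative workflow ID, and that the streamed bytes form a zip archive as in the CLI's pull flow:

    from vellum.client import Vellum

    client = Vellum(api_key="...")
    chunks = client.workflows.pull(
        "my-workflow-id",  # hypothetical Workflow ID
        exclude_code=False,
        include_json=True,
        include_sandbox=True,
        strict=True,
    )
    with open("workflow.zip", "wb") as f:  # assumption: the stream is a zip archive
        for chunk in chunks:
            f.write(chunk)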
vellum/client/types/chat_message_prompt_block.py
@@ -16,9 +16,9 @@ class ChatMessagePromptBlock(UniversalBaseModel):
     A block that represents a chat message in a prompt template.
     """

+    block_type: typing.Literal["CHAT_MESSAGE"] = "CHAT_MESSAGE"
     state: typing.Optional[PromptBlockState] = None
     cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
-    block_type: typing.Literal["CHAT_MESSAGE"] = "CHAT_MESSAGE"
     chat_role: ChatMessageRole
     chat_source: typing.Optional[str] = None
     chat_message_unterminated: typing.Optional[bool] = None
vellum/client/types/function_definition.py
@@ -4,22 +4,41 @@ from ..core.pydantic_utilities import UniversalBaseModel
 import typing
 from .prompt_block_state import PromptBlockState
 from .ephemeral_prompt_cache_config import EphemeralPromptCacheConfig
-from ..core.pydantic_utilities import IS_PYDANTIC_V2
 import pydantic
+from ..core.pydantic_utilities import IS_PYDANTIC_V2


 class FunctionDefinition(UniversalBaseModel):
     """
-    A block that represents a function definition in a prompt template.
+    The definition of a Function (aka "Tool Call") that a Prompt/Model has access to.
     """

     state: typing.Optional[PromptBlockState] = None
     cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
-    name: typing.Optional[str] = None
-    description: typing.Optional[str] = None
-    parameters: typing.Optional[typing.Dict[str, typing.Optional[typing.Any]]] = None
-    forced: typing.Optional[bool] = None
-    strict: typing.Optional[bool] = None
+    name: typing.Optional[str] = pydantic.Field(default=None)
+    """
+    The name identifying the function.
+    """
+
+    description: typing.Optional[str] = pydantic.Field(default=None)
+    """
+    A description to help guide the model when to invoke this function.
+    """
+
+    parameters: typing.Optional[typing.Dict[str, typing.Optional[typing.Any]]] = pydantic.Field(default=None)
+    """
+    An OpenAPI specification of parameters that are supported by this function.
+    """
+
+    forced: typing.Optional[bool] = pydantic.Field(default=None)
+    """
+    Set this option to true to force the model to return a function call of this function.
+    """
+
+    strict: typing.Optional[bool] = pydantic.Field(default=None)
+    """
+    Set this option to use strict schema decoding when available.
+    """

     if IS_PYDANTIC_V2:
         model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
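
The fields themselves are unchanged, only wrapped in pydantic.Field with attached docstrings, so existing construction code keeps working. A small sketch with illustrative values:

    from vellum.client.types.function_definition import FunctionDefinition

    get_weather = FunctionDefinition(
        name="get_weather",
        description="Look up the current weather for a city.",
        parameters={
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
        strict=True,  # opt into strict schema decoding where supported
    )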
vellum/client/types/jinja_prompt_block.py
@@ -13,9 +13,9 @@ class JinjaPromptBlock(UniversalBaseModel):
     A block of Jinja template code that is used to generate a prompt
     """

+    block_type: typing.Literal["JINJA"] = "JINJA"
     state: typing.Optional[PromptBlockState] = None
     cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
-    block_type: typing.Literal["JINJA"] = "JINJA"
     template: str

     if IS_PYDANTIC_V2:
vellum/client/types/plain_text_prompt_block.py
@@ -13,9 +13,9 @@ class PlainTextPromptBlock(UniversalBaseModel):
     A block that holds a plain text string value.
     """

+    block_type: typing.Literal["PLAIN_TEXT"] = "PLAIN_TEXT"
     state: typing.Optional[PromptBlockState] = None
     cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
-    block_type: typing.Literal["PLAIN_TEXT"] = "PLAIN_TEXT"
     text: str

     if IS_PYDANTIC_V2:
vellum/client/types/rich_text_prompt_block.py
@@ -14,9 +14,9 @@ class RichTextPromptBlock(UniversalBaseModel):
     A block that includes a combination of plain text and variable blocks.
     """

+    block_type: typing.Literal["RICH_TEXT"] = "RICH_TEXT"
     state: typing.Optional[PromptBlockState] = None
     cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
-    block_type: typing.Literal["RICH_TEXT"] = "RICH_TEXT"
     blocks: typing.List[RichTextChildBlock]

     if IS_PYDANTIC_V2:
vellum/client/types/variable_prompt_block.py
@@ -13,9 +13,9 @@ class VariablePromptBlock(UniversalBaseModel):
     A block that represents a variable in a prompt template.
     """

+    block_type: typing.Literal["VARIABLE"] = "VARIABLE"
     state: typing.Optional[PromptBlockState] = None
     cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
-    block_type: typing.Literal["VARIABLE"] = "VARIABLE"
     input_variable: str

     if IS_PYDANTIC_V2:
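
All five prompt block models receive the same mechanical change: block_type moves above state and cache_config so the discriminator field is declared first. Defaults are untouched, so construction is unaffected; a quick sketch:

    from vellum.client.types.jinja_prompt_block import JinjaPromptBlock
    from vellum.client.types.variable_prompt_block import VariablePromptBlock

    blocks = [
        JinjaPromptBlock(template="Summarize: {{ text }}"),  # block_type defaults to "JINJA"
        VariablePromptBlock(input_variable="text"),  # block_type defaults to "VARIABLE"
    ]
    assert blocks[0].block_type == "JINJA"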
vellum/plugins/vellum_mypy.py
@@ -172,41 +172,39 @@ class VellumMypyPlugin(Plugin):

     def _dynamic_output_node_class_hook(self, ctx: ClassDefContext, attribute_name: str) -> None:
         """
-        We use this hook to properly annotate the Outputs class for Templating Node using the resolved type
-        of the TemplatingNode's class _OutputType generic.
+        We use this hook to properly annotate the Outputs class for Templating Node and
+        Code Execution Node using the resolved type
+        of the TemplatingNode's and CodeExecutionNode's class _OutputType generic.
         """
-
-        templating_node_info = ctx.cls.info
-        templating_node_bases = ctx.cls.info.bases
-        if not templating_node_bases:
+        node_info = ctx.cls.info
+        node_bases = ctx.cls.info.bases
+        if not node_bases:
             return
-        if not isinstance(templating_node_bases[0], Instance):
+        if not isinstance(node_bases[0], Instance):
             return

-        base_templating_args = templating_node_bases[0].args
-        base_templating_node = templating_node_bases[0].type
-        if not _is_subclass(base_templating_node, "vellum.workflows.nodes.core.templating_node.node.TemplatingNode"):
-            return
+        base_args = node_bases[0].args
+        base_node = node_bases[0].type

-        if len(base_templating_args) != 2:
+        if len(base_args) != 2:
             return

-        base_templating_node_resolved_type = base_templating_args[1]
-        if isinstance(base_templating_node_resolved_type, AnyType):
-            base_templating_node_resolved_type = ctx.api.named_type("builtins.str")
+        base_node_resolved_type = base_args[1]
+        if isinstance(base_node_resolved_type, AnyType):
+            base_node_resolved_type = ctx.api.named_type("builtins.str")

-        base_templating_node_outputs = base_templating_node.names.get("Outputs")
-        if not base_templating_node_outputs:
+        base_node_outputs = base_node.names.get("Outputs")
+        if not base_node_outputs:
             return

-        current_templating_node_outputs = templating_node_info.names.get("Outputs")
-        if not current_templating_node_outputs:
-            templating_node_info.names["Outputs"] = base_templating_node_outputs.copy()
-            new_outputs_sym = templating_node_info.names["Outputs"].node
+        current_node_outputs = node_info.names.get("Outputs")
+        if not current_node_outputs:
+            node_info.names["Outputs"] = base_node_outputs.copy()
+            new_outputs_sym = node_info.names["Outputs"].node
             if isinstance(new_outputs_sym, TypeInfo):
                 result_sym = new_outputs_sym.names[attribute_name].node
                 if isinstance(result_sym, Var):
-                    result_sym.type = base_templating_node_resolved_type
+                    result_sym.type = base_node_resolved_type

     def _base_node_class_hook(self, ctx: ClassDefContext) -> None:
         """
@@ -233,7 +231,6 @@ class VellumMypyPlugin(Plugin):
             type_ = sym.node.type
             if not type_:
                 continue
-
             sym.node.type = self._get_resolvable_type(
                 lambda fullname, types: ctx.api.named_type(fullname, types), type_
             )
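
Dropping the _is_subclass gate on TemplatingNode is what extends the hook to CodeExecutionNode: for either base, the second generic argument now types the Outputs attribute named by attribute_name. A mypy-only sketch of the effect (MyState and RenderScore are hypothetical, and the result attribute name follows TemplatingNode's convention):

    from vellum.workflows.nodes.core.templating_node.node import TemplatingNode
    from vellum.workflows.state.base import BaseState

    class MyState(BaseState):
        score: float

    class RenderScore(TemplatingNode[MyState, float]):
        template = "{{ score }}"

    reveal_type(RenderScore.Outputs.result)  # mypy (with the plugin): builtins.float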
vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py
@@ -1,3 +1,4 @@
+import json
 from uuid import uuid4
 from typing import ClassVar, Generic, Iterator, List, Optional, Tuple, cast

@@ -18,6 +19,7 @@ from vellum.client import RequestOptions
 from vellum.workflows.constants import OMIT
 from vellum.workflows.context import get_parent_context
 from vellum.workflows.errors import WorkflowErrorCode
+from vellum.workflows.events.types import default_serializer
 from vellum.workflows.exceptions import NodeException
 from vellum.workflows.nodes.displayable.bases.base_prompt_node import BasePromptNode
 from vellum.workflows.nodes.displayable.bases.inline_prompt_node.constants import DEFAULT_PROMPT_PARAMETERS
@@ -108,19 +110,30 @@ class BaseInlinePromptNode(BasePromptNode, Generic[StateType]):
                         value=cast(List[ChatMessage], input_value),
                     )
                 )
-            elif isinstance(input_value, dict):
-                # Note: We may want to fail early here if we know that input_value is not
-                # JSON serializable.
+            else:
+                try:
+                    input_value = default_serializer(input_value)
+                except json.JSONDecodeError as e:
+                    raise NodeException(
+                        message=f"Failed to serialize input '{input_name}' of type '{input_value.__class__}': {e}",
+                        code=WorkflowErrorCode.INVALID_INPUTS,
+                    )
+
+                input_variables.append(
+                    VellumVariable(
+                        # TODO: Determine whether or not we actually need an id here and if we do,
+                        #   figure out how to maintain stable id references.
+                        #   https://app.shortcut.com/vellum/story/4080
+                        id=str(uuid4()),
+                        key=input_name,
+                        type="JSON",
+                    )
+                )
                 input_values.append(
                     PromptRequestJsonInput(
                         key=input_name,
                         value=input_value,
                     )
                 )
-            else:
-                raise NodeException(
-                    message=f"Unrecognized input type for input '{input_name}': {input_value.__class__}",
-                    code=WorkflowErrorCode.INVALID_INPUTS,
-                )

         return input_variables, input_values
vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from uuid import uuid4
+from typing import Any, Iterator, List
+
+from vellum.client.core.pydantic_utilities import UniversalBaseModel
+from vellum.client.types.execute_prompt_event import ExecutePromptEvent
+from vellum.client.types.fulfilled_execute_prompt_event import FulfilledExecutePromptEvent
+from vellum.client.types.initiated_execute_prompt_event import InitiatedExecutePromptEvent
+from vellum.client.types.prompt_output import PromptOutput
+from vellum.client.types.prompt_request_json_input import PromptRequestJsonInput
+from vellum.client.types.string_vellum_value import StringVellumValue
+from vellum.workflows.nodes.displayable.inline_prompt_node.node import InlinePromptNode
+
+
+def test_inline_prompt_node__json_inputs(vellum_adhoc_prompt_client):
+    # GIVEN a prompt node with various inputs
+    @dataclass
+    class MyDataClass:
+        hello: str
+
+    class MyPydantic(UniversalBaseModel):
+        example: str
+
+    class MyNode(InlinePromptNode):
+        ml_model = "gpt-4o"
+        blocks = []
+        prompt_inputs = {
+            "a_dict": {"foo": "bar"},
+            "a_list": [1, 2, 3],
+            "a_dataclass": MyDataClass(hello="world"),
+            "a_pydantic": MyPydantic(example="example"),
+        }
+
+    # AND a known response from invoking an inline prompt
+    expected_outputs: List[PromptOutput] = [
+        StringVellumValue(value="Test"),
+    ]
+
+    def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
+        execution_id = str(uuid4())
+        events: List[ExecutePromptEvent] = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            FulfilledExecutePromptEvent(
+                execution_id=execution_id,
+                outputs=expected_outputs,
+            ),
+        ]
+        yield from events
+
+    vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.side_effect = generate_prompt_events
+
+    # WHEN the node is run
+    list(MyNode().run())
+
+    # THEN the prompt is executed with the correct inputs
+    mock_api = vellum_adhoc_prompt_client.adhoc_execute_prompt_stream
+    assert mock_api.call_count == 1
+    assert mock_api.call_args.kwargs["input_values"] == [
+        PromptRequestJsonInput(key="a_dict", type="JSON", value={"foo": "bar"}),
+        PromptRequestJsonInput(key="a_list", type="JSON", value=[1, 2, 3]),
+        PromptRequestJsonInput(key="a_dataclass", type="JSON", value={"hello": "world"}),
+        PromptRequestJsonInput(key="a_pydantic", type="JSON", value={"example": "example"}),
+    ]
+    assert len(mock_api.call_args.kwargs["input_variables"]) == 4
vellum/workflows/sandbox.py
@@ -0,0 +1,51 @@
+from typing import Generic, Sequence, Type
+
+import dotenv
+
+from vellum.workflows.events.workflow import WorkflowEventStream
+from vellum.workflows.inputs.base import BaseInputs
+from vellum.workflows.logging import load_logger
+from vellum.workflows.types.generics import WorkflowType
+from vellum.workflows.workflows.event_filters import root_workflow_event_filter
+
+
+class SandboxRunner(Generic[WorkflowType]):
+    def __init__(self, workflow: Type[WorkflowType], inputs: Sequence[BaseInputs]):
+        if not inputs:
+            raise ValueError("Inputs are required to have at least one defined inputs")
+
+        self._workflow = workflow
+        self._inputs = inputs
+
+        dotenv.load_dotenv()
+        self._logger = load_logger()
+
+    def run(self, index: int = 0):
+        if index < 0:
+            self._logger.warning("Index is less than 0, running first input")
+            index = 0
+        elif index >= len(self._inputs):
+            self._logger.warning("Index is greater than the number of provided inputs, running last input")
+            index = len(self._inputs) - 1
+
+        selected_inputs = self._inputs[index]
+
+        workflow = self._workflow()
+        events = workflow.stream(
+            inputs=selected_inputs,
+            event_filter=root_workflow_event_filter,
+        )
+
+        self._process_events(events)
+
+    def _process_events(self, events: WorkflowEventStream):
+        for event in events:
+            if event.name == "workflow.execution.fulfilled":
+                self._logger.info("Workflow fulfilled!")
+                for output_reference, value in event.outputs:
+                    self._logger.info("----------------------------------")
+                    self._logger.info(f"{output_reference.name}: {value}")
+            elif event.name == "node.execution.initiated":
+                self._logger.info(f"Just started Node: {event.node_definition.__name__}")
+            elif event.name == "node.execution.fulfilled":
+                self._logger.info(f"Just finished Node: {event.node_definition.__name__}")
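
A usage sketch for the new runner; MyWorkflow is a placeholder for any BaseWorkflow subclass, such as the one defined in the test file below:

    from vellum.workflows.inputs.base import BaseInputs
    from vellum.workflows.sandbox import SandboxRunner

    class Inputs(BaseInputs):
        foo: str

    if __name__ == "__main__":
        # MyWorkflow stands in for a BaseWorkflow[Inputs, BaseState] subclass
        runner = SandboxRunner(workflow=MyWorkflow, inputs=[Inputs(foo="first"), Inputs(foo="second")])
        runner.run(index=1)  # out-of-range indices are clamped with a warning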
vellum/workflows/tests/__init__.py
File without changes
vellum/workflows/tests/test_sandbox.py
@@ -0,0 +1,62 @@
+import pytest
+from typing import List
+
+from vellum.workflows.inputs.base import BaseInputs
+from vellum.workflows.nodes.bases.base import BaseNode
+from vellum.workflows.sandbox import SandboxRunner
+from vellum.workflows.state.base import BaseState
+from vellum.workflows.workflows.base import BaseWorkflow
+
+
+@pytest.fixture
+def mock_logger(mocker):
+    return mocker.patch("vellum.workflows.sandbox.load_logger")
+
+
+@pytest.mark.parametrize(
+    ["run_kwargs", "expected_last_log"],
+    [
+        ({}, "final_results: first"),
+        ({"index": 1}, "final_results: second"),
+        ({"index": -4}, "final_results: first"),
+        ({"index": 100}, "final_results: second"),
+    ],
+    ids=["default", "specific", "negative", "out_of_bounds"],
+)
+def test_sandbox_runner__happy_path(mock_logger, run_kwargs, expected_last_log):
+    # GIVEN we capture the logs to stdout
+    logs = []
+    mock_logger.return_value.info.side_effect = lambda msg: logs.append(msg)
+
+    # AND an example workflow
+    class Inputs(BaseInputs):
+        foo: str
+
+    class StartNode(BaseNode):
+        class Outputs(BaseNode.Outputs):
+            bar = Inputs.foo
+
+    class Workflow(BaseWorkflow[Inputs, BaseState]):
+        graph = StartNode
+
+        class Outputs(BaseWorkflow.Outputs):
+            final_results = StartNode.Outputs.bar
+
+    # AND a dataset for this workflow
+    inputs: List[Inputs] = [
+        Inputs(foo="first"),
+        Inputs(foo="second"),
+    ]
+
+    # WHEN we run the sandbox
+    runner = SandboxRunner(workflow=Workflow, inputs=inputs)
+    runner.run(**run_kwargs)
+
+    # THEN we see the logs
+    assert logs == [
+        "Just started Node: StartNode",
+        "Just finished Node: StartNode",
+        "Workflow fulfilled!",
+        "----------------------------------",
+        expected_last_log,
+    ]
vellum/workflows/types/core.py
@@ -1,5 +1,6 @@
 from enum import Enum
 from typing import (  # type: ignore[attr-defined]
+    Any,
     Dict,
     List,
     Union,
@@ -8,22 +9,6 @@ from typing import (  # type: ignore[attr-defined]
     _UnionGenericAlias,
 )

-from vellum import (
-    ChatMessage,
-    FunctionCall,
-    FunctionCallRequest,
-    SearchResult,
-    SearchResultRequest,
-    VellumAudio,
-    VellumAudioRequest,
-    VellumError,
-    VellumErrorRequest,
-    VellumImage,
-    VellumImageRequest,
-    VellumValue,
-    VellumValueRequest,
-)
-
 JsonArray = List["Json"]
 JsonObject = Dict[str, "Json"]
 Json = Union[None, bool, int, float, str, JsonArray, JsonObject]
@@ -42,42 +27,7 @@ class VellumSecret:
         self.name = name


-VellumValuePrimitive = Union[
-    # String inputs
-    str,
-    # Chat history inputs
-    List[ChatMessage],
-    List[ChatMessage],
-    # Search results inputs
-    List[SearchResultRequest],
-    List[SearchResult],
-    # JSON inputs
-    Json,
-    # Number inputs
-    float,
-    # Function Call Inputs
-    FunctionCall,
-    FunctionCallRequest,
-    # Error Inputs
-    VellumError,
-    VellumErrorRequest,
-    # Array Inputs
-    List[VellumValueRequest],
-    List[VellumValue],
-    # Image Inputs
-    VellumImage,
-    VellumImageRequest,
-    # Audio Inputs
-    VellumAudio,
-    VellumAudioRequest,
-    # Vellum Secrets
-    VellumSecret,
-]
-
-EntityInputsInterface = Dict[
-    str,
-    VellumValuePrimitive,
-]
+EntityInputsInterface = Dict[str, Any]


 class MergeBehavior(Enum):
vellum/workflows/utils/functions.py
@@ -1,6 +1,9 @@
 import dataclasses
 import inspect
-from typing import Any, Callable, Union, get_args, get_origin
+from typing import Any, Callable, Optional, Union, get_args, get_origin
+
+from pydantic import BaseModel
+from pydantic_core import PydanticUndefined

 from vellum.client.types.function_definition import FunctionDefinition

@@ -16,7 +19,10 @@ type_map = {
 }


-def _compile_annotation(annotation: Any, defs: dict[str, Any]) -> dict:
+def _compile_annotation(annotation: Optional[Any], defs: dict[str, Any]) -> dict:
+    if annotation is None:
+        return {"type": "null"}
+
     if get_origin(annotation) is Union:
         return {"anyOf": [_compile_annotation(a, defs) for a in get_args(annotation)]}

@@ -37,13 +43,44 @@ def _compile_annotation(annotation: Any, defs: dict[str, Any]) -> dict:
             if field.default is dataclasses.MISSING:
                 required.append(field.name)
             else:
-                properties[field.name]["default"] = field.default
+                properties[field.name]["default"] = _compile_default_value(field.default)
+            defs[annotation.__name__] = {"type": "object", "properties": properties, "required": required}
+        return {"$ref": f"#/$defs/{annotation.__name__}"}
+
+    if issubclass(annotation, BaseModel):
+        if annotation.__name__ not in defs:
+            properties = {}
+            required = []
+            for field_name, field in annotation.model_fields.items():
+                # Mypy is incorrect here, the `annotation` attribute is defined on `FieldInfo`
+                field_annotation = field.annotation  # type: ignore[attr-defined]
+                properties[field_name] = _compile_annotation(field_annotation, defs)
+                if field.default is PydanticUndefined:
+                    required.append(field_name)
+                else:
+                    properties[field_name]["default"] = _compile_default_value(field.default)
             defs[annotation.__name__] = {"type": "object", "properties": properties, "required": required}
+
         return {"$ref": f"#/$defs/{annotation.__name__}"}

     return {"type": type_map[annotation]}


+def _compile_default_value(default: Any) -> Any:
+    if dataclasses.is_dataclass(default):
+        return {
+            field.name: _compile_default_value(getattr(default, field.name)) for field in dataclasses.fields(default)
+        }
+
+    if isinstance(default, BaseModel):
+        return {
+            field_name: _compile_default_value(getattr(default, field_name))
+            for field_name in default.model_fields.keys()
+        }
+
+    return default
+
+
 def compile_function_definition(function: Callable) -> FunctionDefinition:
     """
     Converts a Python function into our Vellum-native FunctionDefinition type.
@@ -62,7 +99,7 @@ def compile_function_definition(function: Callable) -> FunctionDefinition:
         if param.default is inspect.Parameter.empty:
             required.append(param.name)
         else:
-            properties[param.name]["default"] = param.default
+            properties[param.name]["default"] = _compile_default_value(param.default)

     parameters = {"type": "object", "properties": properties, "required": required}
     if defs:
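
Taken together: a Pydantic model (or dataclass) parameter now compiles to a $defs reference, and a model-instance default is flattened to a plain dict via _compile_default_value. A sketch of the expected shape, with illustrative names; the exact schema is produced by the code above:

    from pydantic import BaseModel

    from vellum.workflows.utils.functions import compile_function_definition

    class Location(BaseModel):
        city: str
        country: str = "USA"

    def get_weather(location: Location = Location(city="SF")) -> str:
        """Look up the weather for a location."""
        return "sunny"

    definition = compile_function_definition(get_weather)
    # definition.parameters should resemble:
    # {
    #     "type": "object",
    #     "properties": {
    #         "location": {"$ref": "#/$defs/Location", "default": {"city": "SF", "country": "USA"}},
    #     },
    #     "required": [],
    #     "$defs": {"Location": {"type": "object", "properties": {...}, "required": ["city"]}},
    # }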