vellum-ai 0.12.3__py3-none-any.whl → 0.12.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (33) hide show
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/client/resources/workflows/client.py +32 -0
  3. vellum/client/types/chat_message_prompt_block.py +1 -1
  4. vellum/client/types/function_definition.py +26 -7
  5. vellum/client/types/jinja_prompt_block.py +1 -1
  6. vellum/client/types/plain_text_prompt_block.py +1 -1
  7. vellum/client/types/rich_text_prompt_block.py +1 -1
  8. vellum/client/types/variable_prompt_block.py +1 -1
  9. vellum/plugins/vellum_mypy.py +20 -23
  10. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +21 -8
  11. vellum/workflows/nodes/displayable/inline_prompt_node/tests/__init__.py +0 -0
  12. vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +64 -0
  13. vellum/workflows/sandbox.py +51 -0
  14. vellum/workflows/tests/__init__.py +0 -0
  15. vellum/workflows/tests/test_sandbox.py +62 -0
  16. vellum/workflows/types/core.py +2 -52
  17. vellum/workflows/utils/functions.py +41 -4
  18. vellum/workflows/utils/tests/test_functions.py +93 -0
  19. {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/METADATA +1 -1
  20. {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/RECORD +33 -28
  21. vellum_cli/__init__.py +14 -0
  22. vellum_cli/pull.py +16 -2
  23. vellum_cli/tests/test_pull.py +45 -0
  24. vellum_ee/workflows/display/nodes/base_node_vellum_display.py +2 -2
  25. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +1 -3
  26. vellum_ee/workflows/display/nodes/vellum/conditional_node.py +18 -18
  27. vellum_ee/workflows/display/nodes/vellum/search_node.py +19 -15
  28. vellum_ee/workflows/display/nodes/vellum/templating_node.py +1 -1
  29. vellum_ee/workflows/display/nodes/vellum/utils.py +4 -4
  30. vellum_ee/workflows/display/utils/vellum.py +1 -2
  31. {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/LICENSE +0 -0
  32. {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/WHEEL +0 -0
  33. {vellum_ai-0.12.3.dist-info → vellum_ai-0.12.5.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
18
18
  headers: typing.Dict[str, str] = {
19
19
  "X-Fern-Language": "Python",
20
20
  "X-Fern-SDK-Name": "vellum-ai",
21
- "X-Fern-SDK-Version": "0.12.3",
21
+ "X-Fern-SDK-Version": "0.12.5",
22
22
  }
23
23
  headers["X_API_KEY"] = self.api_key
24
24
  return headers
@@ -27,7 +27,11 @@ class WorkflowsClient:
27
27
  self,
28
28
  id: str,
29
29
  *,
30
+ exclude_code: typing.Optional[bool] = None,
30
31
  format: typing.Optional[WorkflowsPullRequestFormat] = None,
32
+ include_json: typing.Optional[bool] = None,
33
+ include_sandbox: typing.Optional[bool] = None,
34
+ strict: typing.Optional[bool] = None,
31
35
  request_options: typing.Optional[RequestOptions] = None,
32
36
  ) -> typing.Iterator[bytes]:
33
37
  """
@@ -38,8 +42,16 @@ class WorkflowsClient:
38
42
  id : str
39
43
  The ID of the Workflow to pull from
40
44
 
45
+ exclude_code : typing.Optional[bool]
46
+
41
47
  format : typing.Optional[WorkflowsPullRequestFormat]
42
48
 
49
+ include_json : typing.Optional[bool]
50
+
51
+ include_sandbox : typing.Optional[bool]
52
+
53
+ strict : typing.Optional[bool]
54
+
43
55
  request_options : typing.Optional[RequestOptions]
44
56
  Request-specific configuration.
45
57
 
@@ -53,7 +65,11 @@ class WorkflowsClient:
53
65
  base_url=self._client_wrapper.get_environment().default,
54
66
  method="GET",
55
67
  params={
68
+ "exclude_code": exclude_code,
56
69
  "format": format,
70
+ "include_json": include_json,
71
+ "include_sandbox": include_sandbox,
72
+ "strict": strict,
57
73
  },
58
74
  request_options=request_options,
59
75
  ) as _response:
@@ -164,7 +180,11 @@ class AsyncWorkflowsClient:
164
180
  self,
165
181
  id: str,
166
182
  *,
183
+ exclude_code: typing.Optional[bool] = None,
167
184
  format: typing.Optional[WorkflowsPullRequestFormat] = None,
185
+ include_json: typing.Optional[bool] = None,
186
+ include_sandbox: typing.Optional[bool] = None,
187
+ strict: typing.Optional[bool] = None,
168
188
  request_options: typing.Optional[RequestOptions] = None,
169
189
  ) -> typing.AsyncIterator[bytes]:
170
190
  """
@@ -175,8 +195,16 @@ class AsyncWorkflowsClient:
175
195
  id : str
176
196
  The ID of the Workflow to pull from
177
197
 
198
+ exclude_code : typing.Optional[bool]
199
+
178
200
  format : typing.Optional[WorkflowsPullRequestFormat]
179
201
 
202
+ include_json : typing.Optional[bool]
203
+
204
+ include_sandbox : typing.Optional[bool]
205
+
206
+ strict : typing.Optional[bool]
207
+
180
208
  request_options : typing.Optional[RequestOptions]
181
209
  Request-specific configuration.
182
210
 
@@ -190,7 +218,11 @@ class AsyncWorkflowsClient:
190
218
  base_url=self._client_wrapper.get_environment().default,
191
219
  method="GET",
192
220
  params={
221
+ "exclude_code": exclude_code,
193
222
  "format": format,
223
+ "include_json": include_json,
224
+ "include_sandbox": include_sandbox,
225
+ "strict": strict,
194
226
  },
195
227
  request_options=request_options,
196
228
  ) as _response:
@@ -16,9 +16,9 @@ class ChatMessagePromptBlock(UniversalBaseModel):
16
16
  A block that represents a chat message in a prompt template.
17
17
  """
18
18
 
19
+ block_type: typing.Literal["CHAT_MESSAGE"] = "CHAT_MESSAGE"
19
20
  state: typing.Optional[PromptBlockState] = None
20
21
  cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
21
- block_type: typing.Literal["CHAT_MESSAGE"] = "CHAT_MESSAGE"
22
22
  chat_role: ChatMessageRole
23
23
  chat_source: typing.Optional[str] = None
24
24
  chat_message_unterminated: typing.Optional[bool] = None
@@ -4,22 +4,41 @@ from ..core.pydantic_utilities import UniversalBaseModel
4
4
  import typing
5
5
  from .prompt_block_state import PromptBlockState
6
6
  from .ephemeral_prompt_cache_config import EphemeralPromptCacheConfig
7
- from ..core.pydantic_utilities import IS_PYDANTIC_V2
8
7
  import pydantic
8
+ from ..core.pydantic_utilities import IS_PYDANTIC_V2
9
9
 
10
10
 
11
11
  class FunctionDefinition(UniversalBaseModel):
12
12
  """
13
- A block that represents a function definition in a prompt template.
13
+ The definition of a Function (aka "Tool Call") that a Prompt/Model has access to.
14
14
  """
15
15
 
16
16
  state: typing.Optional[PromptBlockState] = None
17
17
  cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
18
- name: typing.Optional[str] = None
19
- description: typing.Optional[str] = None
20
- parameters: typing.Optional[typing.Dict[str, typing.Optional[typing.Any]]] = None
21
- forced: typing.Optional[bool] = None
22
- strict: typing.Optional[bool] = None
18
+ name: typing.Optional[str] = pydantic.Field(default=None)
19
+ """
20
+ The name identifying the function.
21
+ """
22
+
23
+ description: typing.Optional[str] = pydantic.Field(default=None)
24
+ """
25
+ A description to help guide the model when to invoke this function.
26
+ """
27
+
28
+ parameters: typing.Optional[typing.Dict[str, typing.Optional[typing.Any]]] = pydantic.Field(default=None)
29
+ """
30
+ An OpenAPI specification of parameters that are supported by this function.
31
+ """
32
+
33
+ forced: typing.Optional[bool] = pydantic.Field(default=None)
34
+ """
35
+ Set this option to true to force the model to return a function call of this function.
36
+ """
37
+
38
+ strict: typing.Optional[bool] = pydantic.Field(default=None)
39
+ """
40
+ Set this option to use strict schema decoding when available.
41
+ """
23
42
 
24
43
  if IS_PYDANTIC_V2:
25
44
  model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
@@ -13,9 +13,9 @@ class JinjaPromptBlock(UniversalBaseModel):
13
13
  A block of Jinja template code that is used to generate a prompt
14
14
  """
15
15
 
16
+ block_type: typing.Literal["JINJA"] = "JINJA"
16
17
  state: typing.Optional[PromptBlockState] = None
17
18
  cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
18
- block_type: typing.Literal["JINJA"] = "JINJA"
19
19
  template: str
20
20
 
21
21
  if IS_PYDANTIC_V2:
@@ -13,9 +13,9 @@ class PlainTextPromptBlock(UniversalBaseModel):
13
13
  A block that holds a plain text string value.
14
14
  """
15
15
 
16
+ block_type: typing.Literal["PLAIN_TEXT"] = "PLAIN_TEXT"
16
17
  state: typing.Optional[PromptBlockState] = None
17
18
  cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
18
- block_type: typing.Literal["PLAIN_TEXT"] = "PLAIN_TEXT"
19
19
  text: str
20
20
 
21
21
  if IS_PYDANTIC_V2:
@@ -14,9 +14,9 @@ class RichTextPromptBlock(UniversalBaseModel):
14
14
  A block that includes a combination of plain text and variable blocks.
15
15
  """
16
16
 
17
+ block_type: typing.Literal["RICH_TEXT"] = "RICH_TEXT"
17
18
  state: typing.Optional[PromptBlockState] = None
18
19
  cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
19
- block_type: typing.Literal["RICH_TEXT"] = "RICH_TEXT"
20
20
  blocks: typing.List[RichTextChildBlock]
21
21
 
22
22
  if IS_PYDANTIC_V2:
@@ -13,9 +13,9 @@ class VariablePromptBlock(UniversalBaseModel):
13
13
  A block that represents a variable in a prompt template.
14
14
  """
15
15
 
16
+ block_type: typing.Literal["VARIABLE"] = "VARIABLE"
16
17
  state: typing.Optional[PromptBlockState] = None
17
18
  cache_config: typing.Optional[EphemeralPromptCacheConfig] = None
18
- block_type: typing.Literal["VARIABLE"] = "VARIABLE"
19
19
  input_variable: str
20
20
 
21
21
  if IS_PYDANTIC_V2:
@@ -172,41 +172,39 @@ class VellumMypyPlugin(Plugin):
172
172
 
173
173
  def _dynamic_output_node_class_hook(self, ctx: ClassDefContext, attribute_name: str) -> None:
174
174
  """
175
- We use this hook to properly annotate the Outputs class for Templating Node using the resolved type
176
- of the TemplatingNode's class _OutputType generic.
175
+ We use this hook to properly annotate the Outputs class for Templating Node and
176
+ Code Execution Node using the resolved type
177
+ of the TemplatingNode's and CodeExecutionNode's class _OutputType generic.
177
178
  """
178
-
179
- templating_node_info = ctx.cls.info
180
- templating_node_bases = ctx.cls.info.bases
181
- if not templating_node_bases:
179
+ node_info = ctx.cls.info
180
+ node_bases = ctx.cls.info.bases
181
+ if not node_bases:
182
182
  return
183
- if not isinstance(templating_node_bases[0], Instance):
183
+ if not isinstance(node_bases[0], Instance):
184
184
  return
185
185
 
186
- base_templating_args = templating_node_bases[0].args
187
- base_templating_node = templating_node_bases[0].type
188
- if not _is_subclass(base_templating_node, "vellum.workflows.nodes.core.templating_node.node.TemplatingNode"):
189
- return
186
+ base_args = node_bases[0].args
187
+ base_node = node_bases[0].type
190
188
 
191
- if len(base_templating_args) != 2:
189
+ if len(base_args) != 2:
192
190
  return
193
191
 
194
- base_templating_node_resolved_type = base_templating_args[1]
195
- if isinstance(base_templating_node_resolved_type, AnyType):
196
- base_templating_node_resolved_type = ctx.api.named_type("builtins.str")
192
+ base_node_resolved_type = base_args[1]
193
+ if isinstance(base_node_resolved_type, AnyType):
194
+ base_node_resolved_type = ctx.api.named_type("builtins.str")
197
195
 
198
- base_templating_node_outputs = base_templating_node.names.get("Outputs")
199
- if not base_templating_node_outputs:
196
+ base_node_outputs = base_node.names.get("Outputs")
197
+ if not base_node_outputs:
200
198
  return
201
199
 
202
- current_templating_node_outputs = templating_node_info.names.get("Outputs")
203
- if not current_templating_node_outputs:
204
- templating_node_info.names["Outputs"] = base_templating_node_outputs.copy()
205
- new_outputs_sym = templating_node_info.names["Outputs"].node
200
+ current_node_outputs = node_info.names.get("Outputs")
201
+ if not current_node_outputs:
202
+ node_info.names["Outputs"] = base_node_outputs.copy()
203
+ new_outputs_sym = node_info.names["Outputs"].node
206
204
  if isinstance(new_outputs_sym, TypeInfo):
207
205
  result_sym = new_outputs_sym.names[attribute_name].node
208
206
  if isinstance(result_sym, Var):
209
- result_sym.type = base_templating_node_resolved_type
207
+ result_sym.type = base_node_resolved_type
210
208
 
211
209
  def _base_node_class_hook(self, ctx: ClassDefContext) -> None:
212
210
  """
@@ -233,7 +231,6 @@ class VellumMypyPlugin(Plugin):
233
231
  type_ = sym.node.type
234
232
  if not type_:
235
233
  continue
236
-
237
234
  sym.node.type = self._get_resolvable_type(
238
235
  lambda fullname, types: ctx.api.named_type(fullname, types), type_
239
236
  )
@@ -1,3 +1,4 @@
1
+ import json
1
2
  from uuid import uuid4
2
3
  from typing import ClassVar, Generic, Iterator, List, Optional, Tuple, cast
3
4
 
@@ -18,6 +19,7 @@ from vellum.client import RequestOptions
18
19
  from vellum.workflows.constants import OMIT
19
20
  from vellum.workflows.context import get_parent_context
20
21
  from vellum.workflows.errors import WorkflowErrorCode
22
+ from vellum.workflows.events.types import default_serializer
21
23
  from vellum.workflows.exceptions import NodeException
22
24
  from vellum.workflows.nodes.displayable.bases.base_prompt_node import BasePromptNode
23
25
  from vellum.workflows.nodes.displayable.bases.inline_prompt_node.constants import DEFAULT_PROMPT_PARAMETERS
@@ -108,19 +110,30 @@ class BaseInlinePromptNode(BasePromptNode, Generic[StateType]):
108
110
  value=cast(List[ChatMessage], input_value),
109
111
  )
110
112
  )
111
- elif isinstance(input_value, dict):
112
- # Note: We may want to fail early here if we know that input_value is not
113
- # JSON serializable.
113
+ else:
114
+ try:
115
+ input_value = default_serializer(input_value)
116
+ except json.JSONDecodeError as e:
117
+ raise NodeException(
118
+ message=f"Failed to serialize input '{input_name}' of type '{input_value.__class__}': {e}",
119
+ code=WorkflowErrorCode.INVALID_INPUTS,
120
+ )
121
+
122
+ input_variables.append(
123
+ VellumVariable(
124
+ # TODO: Determine whether or not we actually need an id here and if we do,
125
+ # figure out how to maintain stable id references.
126
+ # https://app.shortcut.com/vellum/story/4080
127
+ id=str(uuid4()),
128
+ key=input_name,
129
+ type="JSON",
130
+ )
131
+ )
114
132
  input_values.append(
115
133
  PromptRequestJsonInput(
116
134
  key=input_name,
117
135
  value=input_value,
118
136
  )
119
137
  )
120
- else:
121
- raise NodeException(
122
- message=f"Unrecognized input type for input '{input_name}': {input_value.__class__}",
123
- code=WorkflowErrorCode.INVALID_INPUTS,
124
- )
125
138
 
126
139
  return input_variables, input_values
@@ -0,0 +1,64 @@
1
+ from dataclasses import dataclass
2
+ from uuid import uuid4
3
+ from typing import Any, Iterator, List
4
+
5
+ from vellum.client.core.pydantic_utilities import UniversalBaseModel
6
+ from vellum.client.types.execute_prompt_event import ExecutePromptEvent
7
+ from vellum.client.types.fulfilled_execute_prompt_event import FulfilledExecutePromptEvent
8
+ from vellum.client.types.initiated_execute_prompt_event import InitiatedExecutePromptEvent
9
+ from vellum.client.types.prompt_output import PromptOutput
10
+ from vellum.client.types.prompt_request_json_input import PromptRequestJsonInput
11
+ from vellum.client.types.string_vellum_value import StringVellumValue
12
+ from vellum.workflows.nodes.displayable.inline_prompt_node.node import InlinePromptNode
13
+
14
+
15
+ def test_inline_prompt_node__json_inputs(vellum_adhoc_prompt_client):
16
+ # GIVEN a prompt node with various inputs
17
+ @dataclass
18
+ class MyDataClass:
19
+ hello: str
20
+
21
+ class MyPydantic(UniversalBaseModel):
22
+ example: str
23
+
24
+ class MyNode(InlinePromptNode):
25
+ ml_model = "gpt-4o"
26
+ blocks = []
27
+ prompt_inputs = {
28
+ "a_dict": {"foo": "bar"},
29
+ "a_list": [1, 2, 3],
30
+ "a_dataclass": MyDataClass(hello="world"),
31
+ "a_pydantic": MyPydantic(example="example"),
32
+ }
33
+
34
+ # AND a known response from invoking an inline prompt
35
+ expected_outputs: List[PromptOutput] = [
36
+ StringVellumValue(value="Test"),
37
+ ]
38
+
39
+ def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
40
+ execution_id = str(uuid4())
41
+ events: List[ExecutePromptEvent] = [
42
+ InitiatedExecutePromptEvent(execution_id=execution_id),
43
+ FulfilledExecutePromptEvent(
44
+ execution_id=execution_id,
45
+ outputs=expected_outputs,
46
+ ),
47
+ ]
48
+ yield from events
49
+
50
+ vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.side_effect = generate_prompt_events
51
+
52
+ # WHEN the node is run
53
+ list(MyNode().run())
54
+
55
+ # THEN the prompt is executed with the correct inputs
56
+ mock_api = vellum_adhoc_prompt_client.adhoc_execute_prompt_stream
57
+ assert mock_api.call_count == 1
58
+ assert mock_api.call_args.kwargs["input_values"] == [
59
+ PromptRequestJsonInput(key="a_dict", type="JSON", value={"foo": "bar"}),
60
+ PromptRequestJsonInput(key="a_list", type="JSON", value=[1, 2, 3]),
61
+ PromptRequestJsonInput(key="a_dataclass", type="JSON", value={"hello": "world"}),
62
+ PromptRequestJsonInput(key="a_pydantic", type="JSON", value={"example": "example"}),
63
+ ]
64
+ assert len(mock_api.call_args.kwargs["input_variables"]) == 4
@@ -0,0 +1,51 @@
1
+ from typing import Generic, Sequence, Type
2
+
3
+ import dotenv
4
+
5
+ from vellum.workflows.events.workflow import WorkflowEventStream
6
+ from vellum.workflows.inputs.base import BaseInputs
7
+ from vellum.workflows.logging import load_logger
8
+ from vellum.workflows.types.generics import WorkflowType
9
+ from vellum.workflows.workflows.event_filters import root_workflow_event_filter
10
+
11
+
12
+ class SandboxRunner(Generic[WorkflowType]):
13
+ def __init__(self, workflow: Type[WorkflowType], inputs: Sequence[BaseInputs]):
14
+ if not inputs:
15
+ raise ValueError("Inputs are required to have at least one defined inputs")
16
+
17
+ self._workflow = workflow
18
+ self._inputs = inputs
19
+
20
+ dotenv.load_dotenv()
21
+ self._logger = load_logger()
22
+
23
+ def run(self, index: int = 0):
24
+ if index < 0:
25
+ self._logger.warning("Index is less than 0, running first input")
26
+ index = 0
27
+ elif index >= len(self._inputs):
28
+ self._logger.warning("Index is greater than the number of provided inputs, running last input")
29
+ index = len(self._inputs) - 1
30
+
31
+ selected_inputs = self._inputs[index]
32
+
33
+ workflow = self._workflow()
34
+ events = workflow.stream(
35
+ inputs=selected_inputs,
36
+ event_filter=root_workflow_event_filter,
37
+ )
38
+
39
+ self._process_events(events)
40
+
41
+ def _process_events(self, events: WorkflowEventStream):
42
+ for event in events:
43
+ if event.name == "workflow.execution.fulfilled":
44
+ self._logger.info("Workflow fulfilled!")
45
+ for output_reference, value in event.outputs:
46
+ self._logger.info("----------------------------------")
47
+ self._logger.info(f"{output_reference.name}: {value}")
48
+ elif event.name == "node.execution.initiated":
49
+ self._logger.info(f"Just started Node: {event.node_definition.__name__}")
50
+ elif event.name == "node.execution.fulfilled":
51
+ self._logger.info(f"Just finished Node: {event.node_definition.__name__}")
File without changes
@@ -0,0 +1,62 @@
1
+ import pytest
2
+ from typing import List
3
+
4
+ from vellum.workflows.inputs.base import BaseInputs
5
+ from vellum.workflows.nodes.bases.base import BaseNode
6
+ from vellum.workflows.sandbox import SandboxRunner
7
+ from vellum.workflows.state.base import BaseState
8
+ from vellum.workflows.workflows.base import BaseWorkflow
9
+
10
+
11
+ @pytest.fixture
12
+ def mock_logger(mocker):
13
+ return mocker.patch("vellum.workflows.sandbox.load_logger")
14
+
15
+
16
+ @pytest.mark.parametrize(
17
+ ["run_kwargs", "expected_last_log"],
18
+ [
19
+ ({}, "final_results: first"),
20
+ ({"index": 1}, "final_results: second"),
21
+ ({"index": -4}, "final_results: first"),
22
+ ({"index": 100}, "final_results: second"),
23
+ ],
24
+ ids=["default", "specific", "negative", "out_of_bounds"],
25
+ )
26
+ def test_sandbox_runner__happy_path(mock_logger, run_kwargs, expected_last_log):
27
+ # GIVEN we capture the logs to stdout
28
+ logs = []
29
+ mock_logger.return_value.info.side_effect = lambda msg: logs.append(msg)
30
+
31
+ # AND an example workflow
32
+ class Inputs(BaseInputs):
33
+ foo: str
34
+
35
+ class StartNode(BaseNode):
36
+ class Outputs(BaseNode.Outputs):
37
+ bar = Inputs.foo
38
+
39
+ class Workflow(BaseWorkflow[Inputs, BaseState]):
40
+ graph = StartNode
41
+
42
+ class Outputs(BaseWorkflow.Outputs):
43
+ final_results = StartNode.Outputs.bar
44
+
45
+ # AND a dataset for this workflow
46
+ inputs: List[Inputs] = [
47
+ Inputs(foo="first"),
48
+ Inputs(foo="second"),
49
+ ]
50
+
51
+ # WHEN we run the sandbox
52
+ runner = SandboxRunner(workflow=Workflow, inputs=inputs)
53
+ runner.run(**run_kwargs)
54
+
55
+ # THEN we see the logs
56
+ assert logs == [
57
+ "Just started Node: StartNode",
58
+ "Just finished Node: StartNode",
59
+ "Workflow fulfilled!",
60
+ "----------------------------------",
61
+ expected_last_log,
62
+ ]
@@ -1,5 +1,6 @@
1
1
  from enum import Enum
2
2
  from typing import ( # type: ignore[attr-defined]
3
+ Any,
3
4
  Dict,
4
5
  List,
5
6
  Union,
@@ -8,22 +9,6 @@ from typing import ( # type: ignore[attr-defined]
8
9
  _UnionGenericAlias,
9
10
  )
10
11
 
11
- from vellum import (
12
- ChatMessage,
13
- FunctionCall,
14
- FunctionCallRequest,
15
- SearchResult,
16
- SearchResultRequest,
17
- VellumAudio,
18
- VellumAudioRequest,
19
- VellumError,
20
- VellumErrorRequest,
21
- VellumImage,
22
- VellumImageRequest,
23
- VellumValue,
24
- VellumValueRequest,
25
- )
26
-
27
12
  JsonArray = List["Json"]
28
13
  JsonObject = Dict[str, "Json"]
29
14
  Json = Union[None, bool, int, float, str, JsonArray, JsonObject]
@@ -42,42 +27,7 @@ class VellumSecret:
42
27
  self.name = name
43
28
 
44
29
 
45
- VellumValuePrimitive = Union[
46
- # String inputs
47
- str,
48
- # Chat history inputs
49
- List[ChatMessage],
50
- List[ChatMessage],
51
- # Search results inputs
52
- List[SearchResultRequest],
53
- List[SearchResult],
54
- # JSON inputs
55
- Json,
56
- # Number inputs
57
- float,
58
- # Function Call Inputs
59
- FunctionCall,
60
- FunctionCallRequest,
61
- # Error Inputs
62
- VellumError,
63
- VellumErrorRequest,
64
- # Array Inputs
65
- List[VellumValueRequest],
66
- List[VellumValue],
67
- # Image Inputs
68
- VellumImage,
69
- VellumImageRequest,
70
- # Audio Inputs
71
- VellumAudio,
72
- VellumAudioRequest,
73
- # Vellum Secrets
74
- VellumSecret,
75
- ]
76
-
77
- EntityInputsInterface = Dict[
78
- str,
79
- VellumValuePrimitive,
80
- ]
30
+ EntityInputsInterface = Dict[str, Any]
81
31
 
82
32
 
83
33
  class MergeBehavior(Enum):
@@ -1,6 +1,9 @@
1
1
  import dataclasses
2
2
  import inspect
3
- from typing import Any, Callable, Union, get_args, get_origin
3
+ from typing import Any, Callable, Optional, Union, get_args, get_origin
4
+
5
+ from pydantic import BaseModel
6
+ from pydantic_core import PydanticUndefined
4
7
 
5
8
  from vellum.client.types.function_definition import FunctionDefinition
6
9
 
@@ -16,7 +19,10 @@ type_map = {
16
19
  }
17
20
 
18
21
 
19
- def _compile_annotation(annotation: Any, defs: dict[str, Any]) -> dict:
22
+ def _compile_annotation(annotation: Optional[Any], defs: dict[str, Any]) -> dict:
23
+ if annotation is None:
24
+ return {"type": "null"}
25
+
20
26
  if get_origin(annotation) is Union:
21
27
  return {"anyOf": [_compile_annotation(a, defs) for a in get_args(annotation)]}
22
28
 
@@ -37,13 +43,44 @@ def _compile_annotation(annotation: Any, defs: dict[str, Any]) -> dict:
37
43
  if field.default is dataclasses.MISSING:
38
44
  required.append(field.name)
39
45
  else:
40
- properties[field.name]["default"] = field.default
46
+ properties[field.name]["default"] = _compile_default_value(field.default)
47
+ defs[annotation.__name__] = {"type": "object", "properties": properties, "required": required}
48
+ return {"$ref": f"#/$defs/{annotation.__name__}"}
49
+
50
+ if issubclass(annotation, BaseModel):
51
+ if annotation.__name__ not in defs:
52
+ properties = {}
53
+ required = []
54
+ for field_name, field in annotation.model_fields.items():
55
+ # Mypy is incorrect here, the `annotation` attribute is defined on `FieldInfo`
56
+ field_annotation = field.annotation # type: ignore[attr-defined]
57
+ properties[field_name] = _compile_annotation(field_annotation, defs)
58
+ if field.default is PydanticUndefined:
59
+ required.append(field_name)
60
+ else:
61
+ properties[field_name]["default"] = _compile_default_value(field.default)
41
62
  defs[annotation.__name__] = {"type": "object", "properties": properties, "required": required}
63
+
42
64
  return {"$ref": f"#/$defs/{annotation.__name__}"}
43
65
 
44
66
  return {"type": type_map[annotation]}
45
67
 
46
68
 
69
+ def _compile_default_value(default: Any) -> Any:
70
+ if dataclasses.is_dataclass(default):
71
+ return {
72
+ field.name: _compile_default_value(getattr(default, field.name)) for field in dataclasses.fields(default)
73
+ }
74
+
75
+ if isinstance(default, BaseModel):
76
+ return {
77
+ field_name: _compile_default_value(getattr(default, field_name))
78
+ for field_name in default.model_fields.keys()
79
+ }
80
+
81
+ return default
82
+
83
+
47
84
  def compile_function_definition(function: Callable) -> FunctionDefinition:
48
85
  """
49
86
  Converts a Python function into our Vellum-native FunctionDefinition type.
@@ -62,7 +99,7 @@ def compile_function_definition(function: Callable) -> FunctionDefinition:
62
99
  if param.default is inspect.Parameter.empty:
63
100
  required.append(param.name)
64
101
  else:
65
- properties[param.name]["default"] = param.default
102
+ properties[param.name]["default"] = _compile_default_value(param.default)
66
103
 
67
104
  parameters = {"type": "object", "properties": properties, "required": required}
68
105
  if defs: