vellum-ai 0.14.69__py3-none-any.whl → 0.14.71__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/workflows/environment/__init__.py +2 -1
  3. vellum/workflows/environment/environment.py +10 -3
  4. vellum/workflows/nodes/displayable/code_execution_node/node.py +8 -1
  5. vellum/workflows/nodes/displayable/code_execution_node/tests/test_node.py +53 -0
  6. vellum/workflows/nodes/experimental/tool_calling_node/tests/test_node.py +77 -1
  7. vellum/workflows/nodes/experimental/tool_calling_node/utils.py +2 -2
  8. vellum/workflows/references/environment_variable.py +11 -9
  9. {vellum_ai-0.14.69.dist-info → vellum_ai-0.14.71.dist-info}/METADATA +1 -1
  10. {vellum_ai-0.14.69.dist-info → vellum_ai-0.14.71.dist-info}/RECORD +48 -42
  11. vellum_cli/__init__.py +5 -2
  12. vellum_cli/image_push.py +24 -1
  13. vellum_cli/tests/test_image_push.py +103 -12
  14. vellum_ee/workflows/display/nodes/base_node_display.py +1 -1
  15. vellum_ee/workflows/display/nodes/utils.py +2 -2
  16. vellum_ee/workflows/display/nodes/vellum/api_node.py +2 -2
  17. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +1 -1
  18. vellum_ee/workflows/display/nodes/vellum/conditional_node.py +1 -1
  19. vellum_ee/workflows/display/nodes/vellum/error_node.py +1 -1
  20. vellum_ee/workflows/display/nodes/vellum/final_output_node.py +2 -2
  21. vellum_ee/workflows/display/nodes/vellum/guardrail_node.py +1 -1
  22. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +8 -4
  23. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +9 -1
  24. vellum_ee/workflows/display/nodes/vellum/map_node.py +1 -1
  25. vellum_ee/workflows/display/nodes/vellum/merge_node.py +1 -1
  26. vellum_ee/workflows/display/nodes/vellum/note_node.py +1 -0
  27. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +1 -1
  28. vellum_ee/workflows/display/nodes/vellum/retry_node.py +1 -1
  29. vellum_ee/workflows/display/nodes/vellum/search_node.py +70 -7
  30. vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +1 -1
  31. vellum_ee/workflows/display/nodes/vellum/templating_node.py +1 -1
  32. vellum_ee/workflows/display/nodes/vellum/tests/test_inline_subworkflow_node.py +88 -0
  33. vellum_ee/workflows/display/nodes/vellum/tests/test_search_node.py +104 -0
  34. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +16 -0
  35. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_prompt_node_serialization.py +82 -0
  36. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +9 -1
  37. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +4 -4
  38. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_inline_workflow_serialization.py +59 -297
  39. vellum_ee/workflows/display/tests/workflow_serialization/test_workflow_input_parameterization_error.py +37 -0
  40. vellum_ee/workflows/display/utils/auto_layout.py +130 -0
  41. vellum_ee/workflows/display/utils/expressions.py +17 -1
  42. vellum_ee/workflows/display/utils/tests/__init__.py +0 -0
  43. vellum_ee/workflows/display/utils/tests/test_auto_layout.py +56 -0
  44. vellum_ee/workflows/display/workflows/base_workflow_display.py +15 -10
  45. vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +41 -0
  46. {vellum_ai-0.14.69.dist-info → vellum_ai-0.14.71.dist-info}/LICENSE +0 -0
  47. {vellum_ai-0.14.69.dist-info → vellum_ai-0.14.71.dist-info}/WHEEL +0 -0
  48. {vellum_ai-0.14.69.dist-info → vellum_ai-0.14.71.dist-info}/entry_points.txt +0 -0
@@ -4,13 +4,14 @@ import os
4
4
  import shutil
5
5
  import subprocess
6
6
  import tempfile
7
- from unittest.mock import MagicMock, patch
7
+ from unittest.mock import MagicMock
8
8
  from uuid import uuid4
9
9
  from typing import Generator
10
10
 
11
11
  from click.testing import CliRunner
12
12
  from httpx import Response
13
13
 
14
+ from vellum.client.types.docker_service_token import DockerServiceToken
14
15
  from vellum_cli import main as cli_main
15
16
 
16
17
 
@@ -26,9 +27,18 @@ def mock_temp_dir() -> Generator[str, None, None]:
26
27
  shutil.rmtree(temp_dir)
27
28
 
28
29
 
29
- @patch("subprocess.run")
30
- @patch("docker.from_env")
31
- def test_image_push__self_hosted_happy_path(mock_docker_from_env, mock_run, vellum_client, monkeypatch):
30
+ @pytest.fixture
31
+ def mock_docker_from_env(mocker):
32
+ return mocker.patch("docker.from_env")
33
+
34
+
35
+ @pytest.fixture
36
+ def mock_subprocess_run(mocker):
37
+ return mocker.patch("subprocess.run")
38
+
39
+
40
+ @pytest.mark.usefixtures("vellum_client")
41
+ def test_image_push__self_hosted_happy_path(mock_docker_from_env, mock_subprocess_run, monkeypatch):
32
42
  # GIVEN a self hosted vellum api URL env var
33
43
  monkeypatch.setenv("VELLUM_API_URL", "mycompany.api.com")
34
44
  monkeypatch.setenv("VELLUM_API_KEY", "123456abcdef")
@@ -37,7 +47,7 @@ def test_image_push__self_hosted_happy_path(mock_docker_from_env, mock_run, vell
37
47
  mock_docker_client = MagicMock()
38
48
  mock_docker_from_env.return_value = mock_docker_client
39
49
 
40
- mock_run.side_effect = [
50
+ mock_subprocess_run.side_effect = [
41
51
  subprocess.CompletedProcess(
42
52
  args="", returncode=0, stdout=b'{"manifests": [{"platform": {"architecture": "amd64"}}]}'
43
53
  ),
@@ -56,10 +66,8 @@ def test_image_push__self_hosted_happy_path(mock_docker_from_env, mock_run, vell
56
66
  assert "Image successfully pushed" in result.output
57
67
 
58
68
 
59
- @patch("subprocess.run")
60
- @patch("docker.from_env")
61
69
  def test_image_push__self_hosted_happy_path__workspace_option(
62
- mock_docker_from_env, mock_run, mock_httpx_transport, mock_temp_dir
70
+ mock_docker_from_env, mock_subprocess_run, mock_httpx_transport, mock_temp_dir
63
71
  ):
64
72
  # GIVEN a workspace config with a new env for url
65
73
  with open(os.path.join(mock_temp_dir, "vellum.lock.json"), "w") as f:
@@ -90,7 +98,7 @@ def test_image_push__self_hosted_happy_path__workspace_option(
90
98
  mock_docker_client = MagicMock()
91
99
  mock_docker_from_env.return_value = mock_docker_client
92
100
 
93
- mock_run.side_effect = [
101
+ mock_subprocess_run.side_effect = [
94
102
  subprocess.CompletedProcess(
95
103
  args="", returncode=0, stdout=b'{"manifests": [{"platform": {"architecture": "amd64"}}]}'
96
104
  ),
@@ -144,9 +152,8 @@ def test_image_push__self_hosted_happy_path__workspace_option(
144
152
  assert str(request.url) == "https://api.vellum.mycompany.ai/v1/container-images/push"
145
153
 
146
154
 
147
- @patch("subprocess.run")
148
- @patch("docker.from_env")
149
- def test_image_push__self_hosted_blocks_repo(mock_docker_from_env, mock_run, vellum_client, monkeypatch):
155
+ @pytest.mark.usefixtures("vellum_client", "mock_subprocess_run")
156
+ def test_image_push__self_hosted_blocks_repo(mock_docker_from_env, monkeypatch):
150
157
  # GIVEN a self hosted vellum api URL env var
151
158
  monkeypatch.setenv("VELLUM_API_URL", "mycompany.api.com")
152
159
 
@@ -163,3 +170,87 @@ def test_image_push__self_hosted_blocks_repo(mock_docker_from_env, mock_run, vel
163
170
 
164
171
  # AND gives the error message for self hosted installs not including the repo
165
172
  assert "For adding images to your self hosted install you must include" in result.output
173
+
174
+
175
+ def test_image_push_with_source_success(
176
+ mock_docker_from_env, mock_subprocess_run, vellum_client, monkeypatch, mock_temp_dir
177
+ ):
178
+ monkeypatch.setenv("VELLUM_API_URL", "https://api.vellum.ai")
179
+ monkeypatch.setenv("VELLUM_API_KEY", "123456abcdef")
180
+
181
+ dockerfile_path = os.path.join(mock_temp_dir, "Dockerfile")
182
+ with open(dockerfile_path, "w") as f:
183
+ f.write("FROM alpine:latest\n")
184
+
185
+ mock_docker_client = MagicMock()
186
+ mock_docker_from_env.return_value = mock_docker_client
187
+ mock_docker_client.images.push.return_value = [b'{"status": "Pushed"}']
188
+
189
+ mock_subprocess_run.side_effect = [
190
+ subprocess.CompletedProcess(args="", returncode=0, stdout=b"Build successful"),
191
+ subprocess.CompletedProcess(
192
+ args="", returncode=0, stdout=b'{"manifests": [{"platform": {"architecture": "amd64"}}]}'
193
+ ),
194
+ subprocess.CompletedProcess(args="", returncode=0, stdout=b"sha256:hellosha"),
195
+ ]
196
+
197
+ vellum_client.container_images.docker_service_token.return_value = DockerServiceToken(
198
+ access_token="345678mnopqr", organization_id="test-org", repository="myrepo.net"
199
+ )
200
+
201
+ runner = CliRunner()
202
+ result = runner.invoke(cli_main, ["image", "push", "myimage:latest", "--source", dockerfile_path])
203
+
204
+ assert result.exit_code == 0, result.output
205
+
206
+ build_call = mock_subprocess_run.call_args_list[0]
207
+ assert build_call[0][0] == [
208
+ "docker",
209
+ "buildx",
210
+ "build",
211
+ "-f",
212
+ "Dockerfile",
213
+ "--platform=linux/amd64",
214
+ "-t",
215
+ "myimage:latest",
216
+ ".",
217
+ ]
218
+ assert build_call[1]["cwd"] == mock_temp_dir
219
+
220
+ assert "Docker build completed successfully" in result.output
221
+ assert "Image successfully pushed" in result.output
222
+
223
+
224
+ @pytest.mark.usefixtures("mock_docker_from_env", "mock_subprocess_run", "vellum_client")
225
+ def test_image_push_with_source_dockerfile_not_exists(monkeypatch, mock_temp_dir):
226
+ monkeypatch.setenv("VELLUM_API_URL", "https://api.vellum.ai")
227
+ monkeypatch.setenv("VELLUM_API_KEY", "123456abcdef")
228
+
229
+ nonexistent_dockerfile = os.path.join(mock_temp_dir, "nonexistent_dockerfile")
230
+
231
+ runner = CliRunner()
232
+ result = runner.invoke(cli_main, ["image", "push", "myimage:latest", "--source", nonexistent_dockerfile])
233
+
234
+ assert result.exit_code == 1
235
+ assert "Dockerfile does not exist" in result.output
236
+
237
+
238
+ @pytest.mark.usefixtures("mock_docker_from_env", "vellum_client")
239
+ def test_image_push_with_source_build_fails(mock_subprocess_run, monkeypatch, mock_temp_dir):
240
+ monkeypatch.setenv("VELLUM_API_URL", "https://api.vellum.ai")
241
+ monkeypatch.setenv("VELLUM_API_KEY", "123456abcdef")
242
+
243
+ dockerfile_path = os.path.join(mock_temp_dir, "Dockerfile")
244
+ with open(dockerfile_path, "w") as f:
245
+ f.write("FROM alpine:latest\n")
246
+
247
+ mock_subprocess_run.side_effect = [
248
+ subprocess.CompletedProcess(args="", returncode=1, stderr=b"Build failed: missing dependency"),
249
+ ]
250
+
251
+ runner = CliRunner()
252
+ result = runner.invoke(cli_main, ["image", "push", "myimage:latest", "--source", dockerfile_path])
253
+
254
+ assert result.exit_code == 1
255
+ assert "Docker build failed" in result.output
256
+ assert "Build failed: missing dependency" in result.output
@@ -310,7 +310,7 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
310
310
 
311
311
  return self._get_node_display_uuid("target_handle_id")
312
312
 
313
- def get_target_handle_id_by_source_node_id(self, source_node_id: UUID) -> UUID:
313
+ def get_target_handle_id_by_source_node_id(self, _source_node_id: UUID) -> UUID:
314
314
  """
315
315
  In the vast majority of cases, nodes will only have a single target handle and can be retrieved independently
316
316
  of the source node. However, in rare cases (such as legacy Merge nodes), this method can be overridden to
@@ -8,11 +8,11 @@ _T = TypeVar("_T")
8
8
 
9
9
 
10
10
  @overload
11
- def raise_if_descriptor(node_attr: BaseDescriptor[_T]) -> _T: ...
11
+ def raise_if_descriptor(_node_attr: BaseDescriptor[_T]) -> _T: ...
12
12
 
13
13
 
14
14
  @overload
15
- def raise_if_descriptor(node_attr: _T) -> _T: ...
15
+ def raise_if_descriptor(_node_attr: _T) -> _T: ...
16
16
 
17
17
 
18
18
  def raise_if_descriptor(node_attr: Union[NodeReference[_T], _T]) -> Optional[_T]:
@@ -1,5 +1,5 @@
1
1
  from uuid import UUID
2
- from typing import Any, ClassVar, Dict, Generic, Optional, TypeVar, cast
2
+ from typing import ClassVar, Dict, Generic, Optional, TypeVar, cast
3
3
 
4
4
  from vellum.workflows.nodes.displayable import APINode
5
5
  from vellum.workflows.references.output import OutputReference
@@ -32,7 +32,7 @@ class BaseAPINodeDisplay(BaseNodeDisplay[_APINodeType], Generic[_APINodeType]):
32
32
  }
33
33
 
34
34
  def serialize(
35
- self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **kwargs: Any
35
+ self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **_kwargs
36
36
  ) -> JsonObject:
37
37
  node = self._node
38
38
  node_id = self.node_id
@@ -24,7 +24,7 @@ class BaseCodeExecutionNodeDisplay(BaseNodeDisplay[_CodeExecutionNodeType], Gene
24
24
  }
25
25
 
26
26
  def serialize(
27
- self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **kwargs
27
+ self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **_kwargs
28
28
  ) -> JsonObject:
29
29
  node = self._node
30
30
  node_id = self.node_id
@@ -45,7 +45,7 @@ class BaseConditionalNodeDisplay(BaseNodeDisplay[_ConditionalNodeType], Generic[
45
45
  rule_ids: ClassVar[List[RuleIdMap]]
46
46
  condition_ids: ClassVar[list[ConditionId]]
47
47
 
48
- def serialize(self, display_context: WorkflowDisplayContext, **kwargs: Any) -> JsonObject:
48
+ def serialize(self, display_context: WorkflowDisplayContext, **_kwargs) -> JsonObject:
49
49
  node = self._node
50
50
  node_id = self.node_id
51
51
 
@@ -20,7 +20,7 @@ class BaseErrorNodeDisplay(BaseNodeDisplay[_ErrorNodeType], Generic[_ErrorNodeTy
20
20
 
21
21
  __serializable_inputs__ = {ErrorNode.error}
22
22
 
23
- def serialize(self, display_context: WorkflowDisplayContext, **kwargs) -> JsonObject:
23
+ def serialize(self, display_context: WorkflowDisplayContext, **_kwargs) -> JsonObject:
24
24
  node_id = self.node_id
25
25
  error_source_input_id = self.node_input_ids_by_name.get(
26
26
  ErrorNode.error.name,
@@ -1,5 +1,5 @@
1
1
  from uuid import UUID
2
- from typing import Any, ClassVar, Generic, Optional, TypeVar
2
+ from typing import ClassVar, Generic, Optional, TypeVar
3
3
 
4
4
  from vellum.workflows.nodes.displayable.final_output_node import FinalOutputNode
5
5
  from vellum.workflows.types.core import JsonObject
@@ -19,7 +19,7 @@ NODE_INPUT_KEY = "node_input"
19
19
  class BaseFinalOutputNodeDisplay(BaseNodeDisplay[_FinalOutputNodeType], Generic[_FinalOutputNodeType]):
20
20
  output_name: ClassVar[Optional[str]] = None
21
21
 
22
- def serialize(self, display_context: WorkflowDisplayContext, **kwargs: Any) -> JsonObject:
22
+ def serialize(self, display_context: WorkflowDisplayContext, **_kwargs) -> JsonObject:
23
23
  node = self._node
24
24
  node_id = self.node_id
25
25
 
@@ -15,7 +15,7 @@ class BaseGuardrailNodeDisplay(BaseNodeDisplay[_GuardrailNodeType], Generic[_Gua
15
15
  __serializable_inputs__ = {GuardrailNode.metric_inputs}
16
16
 
17
17
  def serialize(
18
- self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **kwargs
18
+ self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **_kwargs
19
19
  ) -> JsonObject:
20
20
  node = self._node
21
21
  node_id = self.node_id
@@ -2,6 +2,7 @@ from uuid import UUID
2
2
  from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union
3
3
 
4
4
  from vellum import FunctionDefinition, PromptBlock, RichTextChildBlock, VellumVariable
5
+ from vellum.workflows.descriptors.base import BaseDescriptor
5
6
  from vellum.workflows.nodes import InlinePromptNode
6
7
  from vellum.workflows.types.core import JsonObject
7
8
  from vellum.workflows.utils.functions import compile_function_definition
@@ -18,9 +19,10 @@ _InlinePromptNodeType = TypeVar("_InlinePromptNodeType", bound=InlinePromptNode)
18
19
 
19
20
 
20
21
  class BaseInlinePromptNodeDisplay(BaseNodeDisplay[_InlinePromptNodeType], Generic[_InlinePromptNodeType]):
21
- __serializable_inputs__ = {InlinePromptNode.prompt_inputs, InlinePromptNode.functions}
22
+ __serializable_inputs__ = {
23
+ InlinePromptNode.prompt_inputs,
24
+ }
22
25
  __unserializable_attributes__ = {
23
- InlinePromptNode.blocks,
24
26
  InlinePromptNode.parameters,
25
27
  InlinePromptNode.settings,
26
28
  InlinePromptNode.expand_meta,
@@ -28,7 +30,7 @@ class BaseInlinePromptNodeDisplay(BaseNodeDisplay[_InlinePromptNodeType], Generi
28
30
  }
29
31
 
30
32
  def serialize(
31
- self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **kwargs
33
+ self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **_kwargs
32
34
  ) -> JsonObject:
33
35
  node = self._node
34
36
  node_id = self.node_id
@@ -45,7 +47,9 @@ class BaseInlinePromptNodeDisplay(BaseNodeDisplay[_InlinePromptNodeType], Generi
45
47
  ml_model = str(raise_if_descriptor(node.ml_model))
46
48
 
47
49
  blocks: list = [
48
- self._generate_prompt_block(block, input_variable_id_by_name, [i]) for i, block in enumerate(node_blocks)
50
+ self._generate_prompt_block(block, input_variable_id_by_name, [i])
51
+ for i, block in enumerate(node_blocks)
52
+ if not isinstance(block, BaseDescriptor)
49
53
  ]
50
54
 
51
55
  functions = (
@@ -2,8 +2,10 @@ from uuid import UUID
2
2
  from typing import ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar
3
3
 
4
4
  from vellum import VellumVariable
5
+ from vellum.workflows.constants import undefined
5
6
  from vellum.workflows.inputs.base import BaseInputs
6
7
  from vellum.workflows.nodes import InlineSubworkflowNode
8
+ from vellum.workflows.nodes.displayable.bases.utils import primitive_to_vellum_value
7
9
  from vellum.workflows.types.core import JsonObject
8
10
  from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
9
11
  from vellum_ee.workflows.display.nodes.utils import raise_if_descriptor
@@ -24,7 +26,7 @@ class BaseInlineSubworkflowNodeDisplay(
24
26
  __serializable_inputs__ = {InlineSubworkflowNode.subworkflow_inputs}
25
27
 
26
28
  def serialize(
27
- self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **kwargs
29
+ self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **_kwargs
28
30
  ) -> JsonObject:
29
31
  node = self._node
30
32
  node_id = self.node_id
@@ -100,6 +102,12 @@ class BaseInlineSubworkflowNodeDisplay(
100
102
  id=node_inputs_by_key[descriptor.name].id,
101
103
  key=descriptor.name,
102
104
  type=infer_vellum_variable_type(descriptor),
105
+ required=descriptor.instance is undefined,
106
+ default=(
107
+ primitive_to_vellum_value(descriptor.instance).dict()
108
+ if descriptor.instance is not undefined
109
+ else None
110
+ ),
103
111
  )
104
112
  for descriptor in subworkflow_inputs_class
105
113
  ]
@@ -17,7 +17,7 @@ class BaseMapNodeDisplay(BaseAdornmentNodeDisplay[_MapNodeType], Generic[_MapNod
17
17
  __serializable_inputs__ = {MapNode.items} # type: ignore[misc]
18
18
 
19
19
  def serialize(
20
- self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **kwargs
20
+ self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **_kwargs
21
21
  ) -> JsonObject:
22
22
  node = self._node
23
23
  node_id = self.node_id
@@ -17,7 +17,7 @@ class BaseMergeNodeDisplay(BaseNodeDisplay[_MergeNodeType], Generic[_MergeNodeTy
17
17
  super().__init__()
18
18
  self._target_handle_iterator = 0
19
19
 
20
- def serialize(self, display_context: WorkflowDisplayContext, **kwargs: Any) -> JsonObject:
20
+ def serialize(self, display_context: WorkflowDisplayContext, **_kwargs: Any) -> JsonObject:
21
21
  node = self._node
22
22
  node_id = self.node_id
23
23
 
@@ -13,6 +13,7 @@ class BaseNoteNodeDisplay(BaseNodeDisplay[_NoteNodeType], Generic[_NoteNodeType]
13
13
  style: ClassVar[Union[Dict[str, Any], None]] = None
14
14
 
15
15
  def serialize(self, display_context: WorkflowDisplayContext, **kwargs: Any) -> JsonObject:
16
+ del display_context, kwargs # Unused parameters
16
17
  node_id = self.node_id
17
18
 
18
19
  return {
@@ -15,7 +15,7 @@ class BasePromptDeploymentNodeDisplay(BaseNodeDisplay[_PromptDeploymentNodeType]
15
15
  __serializable_inputs__ = {PromptDeploymentNode.prompt_inputs}
16
16
 
17
17
  def serialize(
18
- self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **kwargs
18
+ self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **_kwargs
19
19
  ) -> JsonObject:
20
20
  node = self._node
21
21
  node_id = self.node_id
@@ -18,7 +18,7 @@ _RetryNodeType = TypeVar("_RetryNodeType", bound=RetryNode)
18
18
 
19
19
 
20
20
  class BaseRetryNodeDisplay(BaseAdornmentNodeDisplay[_RetryNodeType], Generic[_RetryNodeType]):
21
- def serialize(self, display_context: WorkflowDisplayContext, **kwargs: Any) -> JsonObject:
21
+ def serialize(self, display_context: WorkflowDisplayContext, **_kwargs: Any) -> JsonObject:
22
22
  node = self._node
23
23
  node_id = self.node_id
24
24
 
@@ -7,6 +7,7 @@ from vellum import (
7
7
  VellumValueLogicalConditionGroupRequest,
8
8
  VellumValueLogicalConditionRequest,
9
9
  )
10
+ from vellum.workflows.nodes.displayable.bases.types import MetadataLogicalCondition, MetadataLogicalConditionGroup
10
11
  from vellum.workflows.nodes.displayable.search_node import SearchNode
11
12
  from vellum.workflows.references import OutputReference
12
13
  from vellum.workflows.types.core import JsonArray, JsonObject
@@ -42,7 +43,7 @@ class BaseSearchNodeDisplay(BaseNodeDisplay[_SearchNodeType], Generic[_SearchNod
42
43
  }
43
44
 
44
45
  def serialize(
45
- self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **kwargs
46
+ self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **_kwargs
46
47
  ) -> JsonObject:
47
48
  node = self._node
48
49
  node_id = self.node_id
@@ -88,7 +89,8 @@ class BaseSearchNodeDisplay(BaseNodeDisplay[_SearchNodeType], Generic[_SearchNod
88
89
  node_inputs: Dict[str, NodeInput] = {}
89
90
 
90
91
  options = raise_if_descriptor(node.options)
91
- filters = options.filters if options else None
92
+ raw_filters = raise_if_descriptor(node.filters)
93
+ filters = raw_filters if raw_filters else options.filters if options else None
92
94
 
93
95
  external_id_filters = filters.external_ids if filters else None
94
96
 
@@ -104,17 +106,21 @@ class BaseSearchNodeDisplay(BaseNodeDisplay[_SearchNodeType], Generic[_SearchNod
104
106
  raw_metadata_filters, display_context=display_context
105
107
  )
106
108
 
107
- result_merging = options.result_merging if options else None
109
+ raw_result_merging = raise_if_descriptor(node.result_merging)
110
+ result_merging = raw_result_merging if raw_result_merging else options.result_merging if options else None
108
111
  result_merging_enabled = True if result_merging and result_merging.enabled else False
109
112
 
110
113
  raw_weights = raise_if_descriptor(node.weights)
111
114
  weights = raw_weights if raw_weights is not None else options.weights if options is not None else None
112
115
 
116
+ raw_limit = raise_if_descriptor(node.limit)
117
+ limit = raw_limit if raw_limit is not None else options.limit if options is not None else None
118
+
113
119
  node_input_names_and_values = [
114
120
  ("query", node.query),
115
121
  ("document_index_id", node.document_index),
116
122
  ("weights", weights.dict() if weights else None),
117
- ("limit", options.limit if options else None),
123
+ ("limit", limit),
118
124
  ("separator", raise_if_descriptor(node.chunk_separator)),
119
125
  (
120
126
  "result_merging_enabled",
@@ -141,7 +147,12 @@ class BaseSearchNodeDisplay(BaseNodeDisplay[_SearchNodeType], Generic[_SearchNod
141
147
 
142
148
  def _serialize_logical_expression(
143
149
  self,
144
- logical_expression: Union[VellumValueLogicalConditionGroupRequest, VellumValueLogicalConditionRequest],
150
+ logical_expression: Union[
151
+ VellumValueLogicalConditionGroupRequest,
152
+ VellumValueLogicalConditionRequest,
153
+ MetadataLogicalConditionGroup,
154
+ MetadataLogicalCondition,
155
+ ],
145
156
  display_context: WorkflowDisplayContext,
146
157
  path: List[int] = [],
147
158
  ) -> Tuple[JsonObject, List[NodeInput]]:
@@ -175,10 +186,10 @@ class BaseSearchNodeDisplay(BaseNodeDisplay[_SearchNodeType], Generic[_SearchNod
175
186
 
176
187
  lhs_query_input_id: UUID = self.metadata_filter_input_id_by_operand_id.get(
177
188
  UUID(lhs_variable_id)
178
- ) or uuid4_from_hash(f"{self.node_id}|{hash(tuple(path))}")
189
+ ) or uuid4_from_hash(f"{self.node_id}|lhs|{hash(tuple(path))}")
179
190
  rhs_query_input_id: UUID = self.metadata_filter_input_id_by_operand_id.get(
180
191
  UUID(rhs_variable_id)
181
- ) or uuid4_from_hash(f"{self.node_id}|{hash(tuple(path))}")
192
+ ) or uuid4_from_hash(f"{self.node_id}|rhs|{hash(tuple(path))}")
182
193
 
183
194
  return (
184
195
  {
@@ -206,5 +217,57 @@ class BaseSearchNodeDisplay(BaseNodeDisplay[_SearchNodeType], Generic[_SearchNod
206
217
  ),
207
218
  ],
208
219
  )
220
+
221
+ elif isinstance(logical_expression, MetadataLogicalConditionGroup):
222
+ conditions = []
223
+ variables = []
224
+ for idx, metadata_condition in enumerate(logical_expression.conditions):
225
+ serialized_condition, serialized_variables = self._serialize_logical_expression(
226
+ metadata_condition, display_context=display_context, path=path + [idx]
227
+ )
228
+ conditions.append(serialized_condition)
229
+ variables.extend(serialized_variables)
230
+
231
+ return (
232
+ {
233
+ "type": "LOGICAL_CONDITION_GROUP",
234
+ "combinator": logical_expression.combinator,
235
+ "conditions": conditions,
236
+ "negated": logical_expression.negated,
237
+ },
238
+ variables,
239
+ )
240
+
241
+ elif isinstance(logical_expression, MetadataLogicalCondition):
242
+ lhs_variable = logical_expression.lhs_variable
243
+ rhs_variable = logical_expression.rhs_variable
244
+
245
+ lhs_query_input_id = uuid4_from_hash(f"{self.node_id}|lhs|{hash(tuple(path))}")
246
+ rhs_query_input_id = uuid4_from_hash(f"{self.node_id}|rhs|{hash(tuple(path))}")
247
+
248
+ return (
249
+ {
250
+ "type": "LOGICAL_CONDITION",
251
+ "lhs_variable_id": str(lhs_query_input_id),
252
+ "operator": logical_expression.operator,
253
+ "rhs_variable_id": str(rhs_query_input_id),
254
+ },
255
+ [
256
+ create_node_input(
257
+ self.node_id,
258
+ f"vellum-query-builder-variable-{lhs_query_input_id}",
259
+ lhs_variable,
260
+ display_context,
261
+ input_id=lhs_query_input_id,
262
+ ),
263
+ create_node_input(
264
+ self.node_id,
265
+ f"vellum-query-builder-variable-{rhs_query_input_id}",
266
+ rhs_variable,
267
+ display_context,
268
+ input_id=rhs_query_input_id,
269
+ ),
270
+ ],
271
+ )
209
272
  else:
210
273
  raise ValueError(f"Unsupported logical expression type: {type(logical_expression)}")
@@ -17,7 +17,7 @@ class BaseSubworkflowDeploymentNodeDisplay(
17
17
  __serializable_inputs__ = {SubworkflowDeploymentNode.subworkflow_inputs}
18
18
 
19
19
  def serialize(
20
- self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **kwargs
20
+ self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **_kwargs
21
21
  ) -> JsonObject:
22
22
  node = self._node
23
23
  node_id = self.node_id
@@ -18,7 +18,7 @@ class BaseTemplatingNodeDisplay(BaseNodeDisplay[_TemplatingNodeType], Generic[_T
18
18
  __serializable_inputs__ = {TemplatingNode.inputs}
19
19
 
20
20
  def serialize(
21
- self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **kwargs
21
+ self, display_context: WorkflowDisplayContext, error_output_id: Optional[UUID] = None, **_kwargs
22
22
  ) -> JsonObject:
23
23
  node = self._node
24
24
  node_id = self.node_id
@@ -0,0 +1,88 @@
1
+ from typing import Optional
2
+
3
+ from vellum.workflows import BaseWorkflow
4
+ from vellum.workflows.inputs.base import BaseInputs
5
+ from vellum.workflows.nodes import InlineSubworkflowNode
6
+ from vellum.workflows.nodes.bases import BaseNode
7
+ from vellum.workflows.state.base import BaseState
8
+ from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
9
+
10
+
11
+ def test_serialize_node__inline_subworkflow_inputs():
12
+ # GIVEN a main workflow with inputs
13
+ class MainInputs(BaseInputs):
14
+ pass
15
+
16
+ # AND an inline subworkflow with inputs
17
+ class NestedInputs(BaseInputs):
18
+ input: str
19
+ input_with_default: str = "default"
20
+ optional_input: Optional[str] = None
21
+ optional_input_with_default: Optional[str] = "optional_default"
22
+
23
+ class NestedNode(BaseNode):
24
+ input = NestedInputs.input
25
+ input_with_default = NestedInputs.input_with_default
26
+
27
+ class Outputs(BaseNode.Outputs):
28
+ result: str
29
+
30
+ def run(self) -> Outputs:
31
+ return self.Outputs(result=f"{self.input}-{self.input_with_default}")
32
+
33
+ class NestedWorkflow(BaseWorkflow[NestedInputs, BaseState]):
34
+ graph = NestedNode
35
+
36
+ class Outputs(BaseWorkflow.Outputs):
37
+ result = NestedNode.Outputs.result
38
+
39
+ class MyInlineSubworkflowNode(InlineSubworkflowNode):
40
+ subworkflow_inputs = {
41
+ "input": "input",
42
+ "input_with_default": "input_with_default",
43
+ "optional_input": "optional_input",
44
+ "optional_input_with_default": "optional_input_with_default",
45
+ }
46
+ subworkflow = NestedWorkflow
47
+
48
+ class MainWorkflow(BaseWorkflow[MainInputs, BaseState]):
49
+ graph = MyInlineSubworkflowNode
50
+
51
+ class Outputs(BaseWorkflow.Outputs):
52
+ result = MyInlineSubworkflowNode.Outputs.result
53
+
54
+ # WHEN the workflow is serialized
55
+ workflow_display = get_workflow_display(workflow_class=MainWorkflow)
56
+ serialized_workflow: dict = workflow_display.serialize()
57
+
58
+ # THEN the inline subworkflow node should have the correct input variables
59
+ inline_subworkflow_node = next(
60
+ node
61
+ for node in serialized_workflow["workflow_raw_data"]["nodes"]
62
+ if node["id"] == str(MyInlineSubworkflowNode.__id__)
63
+ )
64
+
65
+ input_variables = inline_subworkflow_node["data"]["input_variables"]
66
+ assert len(input_variables) == 4
67
+
68
+ input_var = next(var for var in input_variables if var["key"] == "input")
69
+ assert input_var["required"] is True
70
+ assert input_var["default"] is None
71
+ assert input_var["type"] == "STRING"
72
+
73
+ input_with_default_var = next(var for var in input_variables if var["key"] == "input_with_default")
74
+ assert input_with_default_var["required"] is False
75
+ assert input_with_default_var["default"] == {"type": "STRING", "value": "default"}
76
+ assert input_with_default_var["type"] == "STRING"
77
+
78
+ optional_input_var = next(var for var in input_variables if var["key"] == "optional_input")
79
+ assert optional_input_var["required"] is False
80
+ assert optional_input_var["default"] == {"type": "JSON", "value": None}
81
+ assert optional_input_var["type"] == "STRING"
82
+
83
+ optional_input_with_default_var = next(
84
+ var for var in input_variables if var["key"] == "optional_input_with_default"
85
+ )
86
+ assert optional_input_with_default_var["required"] is False
87
+ assert optional_input_with_default_var["default"] == {"type": "STRING", "value": "optional_default"}
88
+ assert optional_input_with_default_var["type"] == "STRING"