vellum-ai 0.12.14__py3-none-any.whl → 0.12.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. vellum/__init__.py +6 -0
  2. vellum/client/__init__.py +2 -6
  3. vellum/client/core/client_wrapper.py +1 -1
  4. vellum/client/environment.py +3 -3
  5. vellum/client/resources/ad_hoc/client.py +2 -6
  6. vellum/client/resources/container_images/client.py +0 -8
  7. vellum/client/resources/metric_definitions/client.py +2 -6
  8. vellum/client/resources/workflows/client.py +8 -8
  9. vellum/client/types/__init__.py +6 -0
  10. vellum/client/types/audio_prompt_block.py +29 -0
  11. vellum/client/types/function_call_prompt_block.py +30 -0
  12. vellum/client/types/image_prompt_block.py +29 -0
  13. vellum/client/types/prompt_block.py +12 -1
  14. vellum/client/types/workflow_push_response.py +1 -0
  15. vellum/prompts/blocks/compilation.py +43 -0
  16. vellum/types/audio_prompt_block.py +3 -0
  17. vellum/types/function_call_prompt_block.py +3 -0
  18. vellum/types/image_prompt_block.py +3 -0
  19. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +10 -2
  20. vellum/workflows/nodes/core/inline_subworkflow_node/tests/test_node.py +16 -0
  21. {vellum_ai-0.12.14.dist-info → vellum_ai-0.12.16.dist-info}/METADATA +11 -9
  22. {vellum_ai-0.12.14.dist-info → vellum_ai-0.12.16.dist-info}/RECORD +59 -48
  23. vellum_cli/__init__.py +14 -0
  24. vellum_cli/config.py +4 -0
  25. vellum_cli/pull.py +20 -5
  26. vellum_cli/push.py +33 -4
  27. vellum_cli/tests/test_pull.py +19 -1
  28. vellum_cli/tests/test_push.py +63 -0
  29. vellum_ee/workflows/display/nodes/vellum/__init__.py +2 -0
  30. vellum_ee/workflows/display/nodes/vellum/api_node.py +3 -3
  31. vellum_ee/workflows/display/nodes/vellum/base_node.py +35 -0
  32. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +2 -2
  33. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +2 -2
  34. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +20 -6
  35. vellum_ee/workflows/display/nodes/vellum/map_node.py +1 -0
  36. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +2 -2
  37. vellum_ee/workflows/display/nodes/vellum/search_node.py +4 -2
  38. vellum_ee/workflows/display/nodes/vellum/templating_node.py +1 -1
  39. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +3 -3
  40. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/__init__.py +0 -0
  41. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +28 -0
  42. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_trigger_serialization.py +123 -0
  43. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +6 -46
  44. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +3 -25
  45. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +168 -0
  46. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +18 -10
  47. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +18 -10
  48. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +3 -25
  49. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +2 -8
  50. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +13 -27
  51. vellum_ee/workflows/display/types.py +5 -1
  52. vellum_ee/workflows/display/utils/vellum.py +3 -3
  53. vellum_ee/workflows/display/vellum.py +4 -0
  54. vellum_ee/workflows/display/workflows/base_workflow_display.py +44 -16
  55. vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +3 -0
  56. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +7 -8
  57. {vellum_ai-0.12.14.dist-info → vellum_ai-0.12.16.dist-info}/LICENSE +0 -0
  58. {vellum_ai-0.12.14.dist-info → vellum_ai-0.12.16.dist-info}/WHEEL +0 -0
  59. {vellum_ai-0.12.14.dist-info → vellum_ai-0.12.16.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,35 @@
1
+ from typing import Any, Generic, TypeVar
2
+
3
+ from vellum.workflows.nodes.bases.base import BaseNode
4
+ from vellum.workflows.types.core import JsonObject
5
+ from vellum.workflows.utils.uuids import uuid4_from_hash
6
+ from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeVellumDisplay
7
+ from vellum_ee.workflows.display.types import WorkflowDisplayContext
8
+ from vellum_ee.workflows.display.vellum import GenericNodeDisplayData
9
+
10
+ _BaseNodeType = TypeVar("_BaseNodeType", bound=BaseNode)
11
+
12
+
13
+ class BaseNodeDisplay(BaseNodeVellumDisplay[_BaseNodeType], Generic[_BaseNodeType]):
14
+ def serialize(self, display_context: WorkflowDisplayContext, **kwargs: Any) -> JsonObject:
15
+ node = self._node
16
+ node_id = self.node_id
17
+
18
+ return {
19
+ "id": str(node_id),
20
+ "label": node.__qualname__,
21
+ "type": "GENERIC",
22
+ "display_data": self.get_generic_node_display_data().dict(),
23
+ "definition": self.get_definition().dict(),
24
+ "trigger": {
25
+ "id": str(uuid4_from_hash(f"{node_id}|trigger")),
26
+ "merge_behavior": node.Trigger.merge_behavior.value,
27
+ },
28
+ "ports": [],
29
+ "adornments": None,
30
+ "attributes": [],
31
+ }
32
+
33
+ def get_generic_node_display_data(self) -> GenericNodeDisplayData:
34
+ explicit_value = self._get_explicit_node_display_attr("display_data", GenericNodeDisplayData)
35
+ return explicit_value if explicit_value else GenericNodeDisplayData()
@@ -70,8 +70,8 @@ class BaseCodeExecutionNodeDisplay(BaseNodeVellumDisplay[_CodeExecutionNodeType]
70
70
 
71
71
  packages = raise_if_descriptor(node.packages)
72
72
 
73
- _, output_display = display_context.node_output_displays[node.Outputs.result]
74
- _, log_output_display = display_context.node_output_displays[node.Outputs.log]
73
+ _, output_display = display_context.global_node_output_displays[node.Outputs.result]
74
+ _, log_output_display = display_context.global_node_output_displays[node.Outputs.log]
75
75
 
76
76
  output_type = primitive_type_to_vellum_variable_type(node.get_output_type())
77
77
 
@@ -29,8 +29,8 @@ class BaseInlinePromptNodeDisplay(BaseNodeVellumDisplay[_InlinePromptNodeType],
29
29
  node_inputs, prompt_inputs = self._generate_node_and_prompt_inputs(node_id, node, display_context)
30
30
  input_variable_id_by_name = {prompt_input.key: prompt_input.id for prompt_input in prompt_inputs}
31
31
 
32
- _, output_display = display_context.node_output_displays[node.Outputs.text]
33
- _, array_display = display_context.node_output_displays[node.Outputs.results]
32
+ _, output_display = display_context.global_node_output_displays[node.Outputs.text]
33
+ _, array_display = display_context.global_node_output_displays[node.Outputs.results]
34
34
  node_blocks = raise_if_descriptor(node.blocks)
35
35
 
36
36
  return {
@@ -2,6 +2,7 @@ from uuid import UUID
2
2
  from typing import ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, cast
3
3
 
4
4
  from vellum import VellumVariable
5
+ from vellum.workflows.inputs.base import BaseInputs
5
6
  from vellum.workflows.nodes import InlineSubworkflowNode
6
7
  from vellum.workflows.types.core import JsonObject
7
8
  from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeVellumDisplay
@@ -60,12 +61,25 @@ class BaseInlineSubworkflowNodeDisplay(
60
61
  node: Type[InlineSubworkflowNode],
61
62
  display_context: WorkflowDisplayContext,
62
63
  ) -> Tuple[List[NodeInput], List[VellumVariable]]:
64
+ subworkflow = raise_if_descriptor(node.subworkflow)
65
+ subworkflow_inputs_class = subworkflow.get_inputs_class()
63
66
  subworkflow_inputs = raise_if_descriptor(node.subworkflow_inputs)
64
- subworkflow_entries = (
65
- [(variable_name, variable_value) for variable_name, variable_value in subworkflow_inputs.items()]
66
- if isinstance(subworkflow_inputs, dict)
67
- else [(variable_ref.name, variable_value) for variable_ref, variable_value in subworkflow_inputs]
68
- )
67
+
68
+ if isinstance(subworkflow_inputs, BaseInputs):
69
+ subworkflow_entries = [
70
+ (variable_ref.name, variable_value) for variable_ref, variable_value in subworkflow_inputs
71
+ ]
72
+ elif isinstance(subworkflow_inputs, dict):
73
+ subworkflow_entries = [
74
+ (variable_name, variable_value) for variable_name, variable_value in subworkflow_inputs.items()
75
+ ]
76
+ else:
77
+ subworkflow_entries = [
78
+ (descriptor.name, getattr(subworkflow_inputs_class, descriptor.name))
79
+ for descriptor in subworkflow_inputs_class
80
+ if hasattr(subworkflow_inputs_class, descriptor.name)
81
+ ]
82
+
69
83
  node_inputs = [
70
84
  create_node_input(
71
85
  node_id=node_id,
@@ -83,7 +97,7 @@ class BaseInlineSubworkflowNodeDisplay(
83
97
  key=descriptor.name,
84
98
  type=infer_vellum_variable_type(descriptor),
85
99
  )
86
- for descriptor in raise_if_descriptor(node.subworkflow).get_inputs_class()
100
+ for descriptor in subworkflow_inputs_class
87
101
  ]
88
102
 
89
103
  return node_inputs, workflow_inputs
@@ -33,6 +33,7 @@ class BaseMapNodeDisplay(BaseNodeVellumDisplay[_MapNodeType], Generic[_MapNodeTy
33
33
  subworkflow_display = get_workflow_display(
34
34
  base_display_class=display_context.workflow_display_class,
35
35
  workflow_class=subworkflow,
36
+ parent_display_context=display_context,
36
37
  )
37
38
  serialized_subworkflow = subworkflow_display.serialize()
38
39
 
@@ -38,8 +38,8 @@ class BasePromptDeploymentNodeDisplay(
38
38
  for variable_name, variable_value in prompt_inputs.items()
39
39
  ]
40
40
 
41
- _, output_display = display_context.node_output_displays[cast(OutputReference, node.Outputs.text)]
42
- _, array_display = display_context.node_output_displays[cast(OutputReference, node.Outputs.results)]
41
+ _, output_display = display_context.global_node_output_displays[cast(OutputReference, node.Outputs.text)]
42
+ _, array_display = display_context.global_node_output_displays[cast(OutputReference, node.Outputs.results)]
43
43
 
44
44
  # TODO: Pass through the name instead of retrieving the ID
45
45
  # https://app.shortcut.com/vellum/story/4702
@@ -39,8 +39,10 @@ class BaseSearchNodeDisplay(BaseNodeVellumDisplay[_SearchNodeType], Generic[_Sea
39
39
  node_id = self.node_id
40
40
  node_inputs = self._generate_search_node_inputs(node_id, node, display_context)
41
41
 
42
- _, results_output_display = display_context.node_output_displays[cast(OutputReference, node.Outputs.results)]
43
- _, text_output_display = display_context.node_output_displays[cast(OutputReference, node.Outputs.text)]
42
+ _, results_output_display = display_context.global_node_output_displays[
43
+ cast(OutputReference, node.Outputs.results)
44
+ ]
45
+ _, text_output_display = display_context.global_node_output_displays[cast(OutputReference, node.Outputs.text)]
44
46
 
45
47
  return {
46
48
  "id": str(node_id),
@@ -56,7 +56,7 @@ class BaseTemplatingNodeDisplay(BaseNodeVellumDisplay[_TemplatingNodeType], Gene
56
56
  # Misc type ignore is due to `node.Outputs` being generic
57
57
  # https://app.shortcut.com/vellum/story/4784
58
58
  output_descriptor = node.Outputs.result # type: ignore [misc]
59
- _, output_display = display_context.node_output_displays[output_descriptor]
59
+ _, output_display = display_context.global_node_output_displays[output_descriptor]
60
60
  inferred_output_type = primitive_type_to_vellum_variable_type(output_descriptor)
61
61
 
62
62
  return {
@@ -91,18 +91,18 @@ def test_create_node_input_value_pointer_rules(
91
91
  entrypoint_node_source_handle_id=uuid4(),
92
92
  entrypoint_node_display=NodeDisplayData(),
93
93
  ),
94
- workflow_input_displays={
94
+ global_workflow_input_displays={
95
95
  cast(WorkflowInputReference, Inputs.example_workflow_input): WorkflowInputsVellumDisplayOverrides(
96
96
  id=UUID("a154c29d-fac0-4cd0-ba88-bc52034f5470"),
97
97
  ),
98
98
  },
99
- node_output_displays={
99
+ global_node_output_displays={
100
100
  cast(OutputReference, MyNodeA.Outputs.output): (
101
101
  MyNodeA,
102
102
  NodeOutputDisplay(id=UUID("4b16a629-11a1-4b3f-a965-a57b872d13b8"), name="output"),
103
103
  ),
104
104
  },
105
- node_displays={
105
+ global_node_displays={
106
106
  MyNodeA: MyNodeADisplay(),
107
107
  },
108
108
  ),
@@ -0,0 +1,28 @@
1
+ import pytest
2
+ from uuid import uuid4
3
+
4
+ from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
5
+ from vellum_ee.workflows.display.nodes.vellum.base_node import BaseNodeDisplay
6
+ from vellum_ee.workflows.display.types import WorkflowDisplayContext
7
+ from vellum_ee.workflows.display.vellum import NodeDisplayData, WorkflowMetaVellumDisplay
8
+ from vellum_ee.workflows.display.workflows.vellum_workflow_display import VellumWorkflowDisplay
9
+
10
+
11
+ @pytest.fixture()
12
+ def serialize_node():
13
+ def _serialize_node(node_class) -> dict:
14
+ node_display_class = get_node_display_class(BaseNodeDisplay, node_class)
15
+ node_display = node_display_class()
16
+
17
+ context: WorkflowDisplayContext = WorkflowDisplayContext(
18
+ workflow_display_class=VellumWorkflowDisplay,
19
+ workflow_display=WorkflowMetaVellumDisplay(
20
+ entrypoint_node_id=uuid4(),
21
+ entrypoint_node_source_handle_id=uuid4(),
22
+ entrypoint_node_display=NodeDisplayData(),
23
+ ),
24
+ node_displays={node_class: node_display},
25
+ )
26
+ return node_display.serialize(context)
27
+
28
+ return _serialize_node
@@ -0,0 +1,123 @@
1
+ from deepdiff import DeepDiff
2
+
3
+ from vellum.workflows.inputs.base import BaseInputs
4
+ from vellum.workflows.nodes.bases.base import BaseNode
5
+ from vellum.workflows.types.core import MergeBehavior
6
+
7
+
8
+ class Inputs(BaseInputs):
9
+ input: str
10
+
11
+
12
+ class BasicGenericNode(BaseNode):
13
+ class Outputs(BaseNode.Outputs):
14
+ output = Inputs.input
15
+
16
+
17
+ class AwaitAnyGenericNode(BaseNode):
18
+ class Outputs(BaseNode.Outputs):
19
+ output = Inputs.input
20
+
21
+ class Trigger(BaseNode.Trigger):
22
+ merge_behavior = MergeBehavior.AWAIT_ANY
23
+
24
+
25
+ class AwaitAllGenericNode(BaseNode):
26
+ class Outputs(BaseNode.Outputs):
27
+ output = Inputs.input
28
+
29
+ class Trigger(BaseNode.Trigger):
30
+ merge_behavior = MergeBehavior.AWAIT_ALL
31
+
32
+
33
+ def test_serialize_node__basic(serialize_node):
34
+ serialized_node = serialize_node(BasicGenericNode)
35
+ assert not DeepDiff(
36
+ {
37
+ "id": "c2ed23f7-f6cb-4a56-a91c-2e5f9d8fda7f",
38
+ "label": "BasicGenericNode",
39
+ "type": "GENERIC",
40
+ "display_data": {"position": {"x": 0.0, "y": 0.0}},
41
+ "definition": {
42
+ "name": "BasicGenericNode",
43
+ "module": [
44
+ "vellum_ee",
45
+ "workflows",
46
+ "display",
47
+ "tests",
48
+ "workflow_serialization",
49
+ "generic_nodes",
50
+ "test_trigger_serialization",
51
+ ],
52
+ "bases": [{"name": "BaseNode", "module": ["vellum", "workflows", "nodes", "bases", "base"]}],
53
+ },
54
+ "trigger": {"id": "9d3a1b3d-4a38-4f2e-bbf1-dd8be152bce8", "merge_behavior": "AWAIT_ANY"},
55
+ "ports": [],
56
+ "adornments": None,
57
+ "attributes": [],
58
+ },
59
+ serialized_node,
60
+ ignore_order=True,
61
+ )
62
+
63
+
64
+ def test_serialize_node__await_any(serialize_node):
65
+ serialized_node = serialize_node(AwaitAnyGenericNode)
66
+ assert not DeepDiff(
67
+ {
68
+ "id": "0ba67f76-aaff-4bd4-a20f-73a32ef5810d",
69
+ "label": "AwaitAnyGenericNode",
70
+ "type": "GENERIC",
71
+ "display_data": {"position": {"x": 0.0, "y": 0.0}},
72
+ "definition": {
73
+ "name": "AwaitAnyGenericNode",
74
+ "module": [
75
+ "vellum_ee",
76
+ "workflows",
77
+ "display",
78
+ "tests",
79
+ "workflow_serialization",
80
+ "generic_nodes",
81
+ "test_trigger_serialization",
82
+ ],
83
+ "bases": [{"name": "BaseNode", "module": ["vellum", "workflows", "nodes", "bases", "base"]}],
84
+ },
85
+ "trigger": {"id": "ffa72187-9a18-453f-ae55-b77aad332630", "merge_behavior": "AWAIT_ANY"},
86
+ "ports": [],
87
+ "adornments": None,
88
+ "attributes": [],
89
+ },
90
+ serialized_node,
91
+ ignore_order=True,
92
+ )
93
+
94
+
95
+ def test_serialize_node__await_all(serialize_node):
96
+ serialized_node = serialize_node(AwaitAllGenericNode)
97
+ assert not DeepDiff(
98
+ {
99
+ "id": "09d06cd3-06ea-40cc-afd8-17ad88542271",
100
+ "label": "AwaitAllGenericNode",
101
+ "type": "GENERIC",
102
+ "display_data": {"position": {"x": 0.0, "y": 0.0}},
103
+ "definition": {
104
+ "name": "AwaitAllGenericNode",
105
+ "module": [
106
+ "vellum_ee",
107
+ "workflows",
108
+ "display",
109
+ "tests",
110
+ "workflow_serialization",
111
+ "generic_nodes",
112
+ "test_trigger_serialization",
113
+ ],
114
+ "bases": [{"name": "BaseNode", "module": ["vellum", "workflows", "nodes", "bases", "base"]}],
115
+ },
116
+ "trigger": {"id": "62074276-c817-476d-b59d-da523ae3f218", "merge_behavior": "AWAIT_ALL"},
117
+ "ports": [],
118
+ "adornments": None,
119
+ "attributes": [],
120
+ },
121
+ serialized_node,
122
+ ignore_order=True,
123
+ )
@@ -1,5 +1,4 @@
1
1
  import pytest
2
- from unittest import mock
3
2
 
4
3
  from deepdiff import DeepDiff
5
4
 
@@ -21,7 +20,6 @@ from vellum.workflows.expressions.less_than import LessThanExpression
21
20
  from vellum.workflows.expressions.less_than_or_equal_to import LessThanOrEqualToExpression
22
21
  from vellum.workflows.expressions.not_between import NotBetweenExpression
23
22
  from vellum.workflows.expressions.not_in import NotInExpression
24
- from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeVellumDisplay
25
23
  from vellum_ee.workflows.display.workflows import VellumWorkflowDisplay
26
24
  from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
27
25
 
@@ -33,12 +31,7 @@ def test_serialize_workflow():
33
31
  # GIVEN a Workflow that uses a ConditionalNode
34
32
  # WHEN we serialize it
35
33
  workflow_display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=CategoryWorkflow)
36
-
37
- # TODO: Support serialization of BaseNode
38
- # https://app.shortcut.com/vellum/story/4871/support-serialization-of-base-node
39
- with mock.patch.object(BaseNodeVellumDisplay, "serialize") as mocked_serialize:
40
- mocked_serialize.return_value = {"type": "MOCKED"}
41
- serialized_workflow: dict = workflow_display.serialize()
34
+ serialized_workflow: dict = workflow_display.serialize()
42
35
 
43
36
  # THEN we should get a serialized representation of the Workflow
44
37
  assert serialized_workflow.keys() == {
@@ -455,26 +448,8 @@ def test_serialize_workflow():
455
448
  ignore_order=True,
456
449
  )
457
450
 
458
- assert not DeepDiff(
459
- [
460
- {
461
- "type": "MOCKED",
462
- },
463
- {
464
- "type": "MOCKED",
465
- },
466
- {
467
- "type": "MOCKED",
468
- },
469
- {
470
- "type": "MOCKED",
471
- },
472
- {
473
- "type": "MOCKED",
474
- },
475
- ],
476
- workflow_raw_data["nodes"][2:7],
477
- )
451
+ passthrough_nodes = [node for node in workflow_raw_data["nodes"] if node["type"] == "GENERIC"]
452
+ assert len(passthrough_nodes) == 5
478
453
 
479
454
  assert not DeepDiff(
480
455
  [
@@ -917,12 +892,7 @@ def test_conditional_node_serialize_all_operators_with_lhs_and_rhs(descriptor, o
917
892
  workflow_cls = create_simple_workflow(descriptor)
918
893
 
919
894
  workflow_display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=workflow_cls)
920
-
921
- # TODO: Support serialization of BaseNode
922
- # https://app.shortcut.com/vellum/story/4871/support-serialization-of-base-node
923
- with mock.patch.object(BaseNodeVellumDisplay, "serialize") as mocked_serialize:
924
- mocked_serialize.return_value = {"type": "MOCKED"}
925
- serialized_workflow: dict = workflow_display.serialize()
895
+ serialized_workflow: dict = workflow_display.serialize()
926
896
 
927
897
  # THEN we should get a serialized representation of the Workflow
928
898
  assert serialized_workflow.keys() == {
@@ -1041,12 +1011,7 @@ def test_conditional_node_serialize_all_operators_with_expression(descriptor, op
1041
1011
  workflow_cls = create_simple_workflow(descriptor)
1042
1012
 
1043
1013
  workflow_display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=workflow_cls)
1044
-
1045
- # TODO: Support serialization of BaseNode
1046
- # https://app.shortcut.com/vellum/story/4871/support-serialization-of-base-node
1047
- with mock.patch.object(BaseNodeVellumDisplay, "serialize") as mocked_serialize:
1048
- mocked_serialize.return_value = {"type": "MOCKED"}
1049
- serialized_workflow: dict = workflow_display.serialize()
1014
+ serialized_workflow: dict = workflow_display.serialize()
1050
1015
 
1051
1016
  # THEN we should get a serialized representation of the Workflow
1052
1017
  assert serialized_workflow.keys() == {
@@ -1152,12 +1117,7 @@ def test_conditional_node_serialize_all_operators_with_value_and_start_and_end(d
1152
1117
  workflow_cls = create_simple_workflow(descriptor)
1153
1118
 
1154
1119
  workflow_display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=workflow_cls)
1155
-
1156
- # TODO: Support serialization of BaseNode
1157
- # https://app.shortcut.com/vellum/story/4871/support-serialization-of-base-node
1158
- with mock.patch.object(BaseNodeVellumDisplay, "serialize") as mocked_serialize:
1159
- mocked_serialize.return_value = {"type": "MOCKED"}
1160
- serialized_workflow: dict = workflow_display.serialize()
1120
+ serialized_workflow: dict = workflow_display.serialize()
1161
1121
 
1162
1122
  # THEN we should get a serialized representation of the Workflow
1163
1123
  assert serialized_workflow.keys() == {
@@ -1,8 +1,5 @@
1
- from unittest import mock
2
-
3
1
  from deepdiff import DeepDiff
4
2
 
5
- from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeVellumDisplay
6
3
  from vellum_ee.workflows.display.workflows import VellumWorkflowDisplay
7
4
  from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
8
5
 
@@ -16,11 +13,7 @@ def test_serialize_workflow():
16
13
  base_display_class=VellumWorkflowDisplay, workflow_class=BasicErrorNodeWorkflow
17
14
  )
18
15
 
19
- # TODO: Support serialization of BaseNode
20
- # https://app.shortcut.com/vellum/story/4871/support-serialization-of-base-node
21
- with mock.patch.object(BaseNodeVellumDisplay, "serialize") as mocked_serialize:
22
- mocked_serialize.return_value = {"type": "MOCKED"}
23
- serialized_workflow: dict = workflow_display.serialize()
16
+ serialized_workflow: dict = workflow_display.serialize()
24
17
 
25
18
  # THEN we should get a serialized representation of the Workflow
26
19
  assert serialized_workflow.keys() == {
@@ -129,23 +122,8 @@ def test_serialize_workflow():
129
122
  ignore_order=True,
130
123
  )
131
124
 
132
- mocked_base_nodes = [
133
- node
134
- for i, node in enumerate(workflow_raw_data["nodes"])
135
- if i != error_index and i != 0 and i != len(workflow_raw_data["nodes"]) - 1
136
- ]
137
-
138
- assert not DeepDiff(
139
- [
140
- {
141
- "type": "MOCKED",
142
- },
143
- {
144
- "type": "MOCKED",
145
- },
146
- ],
147
- mocked_base_nodes,
148
- )
125
+ passthrough_nodes = [node for node in workflow_raw_data["nodes"] if node["type"] == "GENERIC"]
126
+ assert len(passthrough_nodes) == 2
149
127
 
150
128
  terminal_node = workflow_raw_data["nodes"][-1]
151
129
  assert not DeepDiff(
@@ -0,0 +1,168 @@
1
+ from deepdiff import DeepDiff
2
+
3
+ from vellum_ee.workflows.display.workflows import VellumWorkflowDisplay
4
+ from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
5
+
6
+ from tests.workflows.basic_generic_node.workflow import BasicGenericNodeWorkflow
7
+
8
+
9
+ def test_serialize_workflow(vellum_client):
10
+ # GIVEN a Workflow that uses a generic node
11
+ # WHEN we serialize it
12
+ workflow_display = get_workflow_display(
13
+ base_display_class=VellumWorkflowDisplay, workflow_class=BasicGenericNodeWorkflow
14
+ )
15
+
16
+ serialized_workflow: dict = workflow_display.serialize()
17
+
18
+ # THEN we should get a serialized representation of the Workflow
19
+ assert serialized_workflow.keys() == {
20
+ "workflow_raw_data",
21
+ "input_variables",
22
+ "output_variables",
23
+ }
24
+
25
+ # AND its input variables should be what we expect
26
+ input_variables = serialized_workflow["input_variables"]
27
+ assert len(input_variables) == 1
28
+ assert not DeepDiff(
29
+ [
30
+ {
31
+ "id": "a07c2273-34a7-42b5-bcad-143b6127cc8a",
32
+ "key": "input",
33
+ "type": "STRING",
34
+ "default": None,
35
+ "required": True,
36
+ "extensions": {"color": None},
37
+ },
38
+ ],
39
+ input_variables,
40
+ ignore_order=True,
41
+ )
42
+
43
+ # AND its output variables should be what we expect
44
+ output_variables = serialized_workflow["output_variables"]
45
+ assert len(output_variables) == 1
46
+ assert not DeepDiff(
47
+ [
48
+ {"id": "2b6389d0-266a-4be4-843e-4e543dd3d727", "key": "output", "type": "STRING"},
49
+ ],
50
+ output_variables,
51
+ ignore_order=True,
52
+ )
53
+
54
+ # AND its raw data should be what we expect
55
+ workflow_raw_data = serialized_workflow["workflow_raw_data"]
56
+ assert workflow_raw_data.keys() == {"edges", "nodes", "display_data", "definition"}
57
+ assert len(workflow_raw_data["edges"]) == 2
58
+ assert len(workflow_raw_data["nodes"]) == 3
59
+
60
+ # AND each node should be serialized correctly
61
+ entrypoint_node = workflow_raw_data["nodes"][0]
62
+ assert entrypoint_node == {
63
+ "id": "f1e4678f-c470-400b-a40e-c8922cc99a86",
64
+ "type": "ENTRYPOINT",
65
+ "inputs": [],
66
+ "data": {"label": "Entrypoint Node", "source_handle_id": "40201804-8beb-43ad-8873-a027759512f1"},
67
+ "display_data": {"position": {"x": 0.0, "y": 0.0}},
68
+ "definition": {
69
+ "name": "BaseNode",
70
+ "module": [
71
+ "vellum",
72
+ "workflows",
73
+ "nodes",
74
+ "bases",
75
+ "base",
76
+ ],
77
+ "bases": [],
78
+ },
79
+ }
80
+
81
+ api_node = workflow_raw_data["nodes"][1]
82
+ assert api_node["id"] == "c2ed23f7-f6cb-4a56-a91c-2e5f9d8fda7f"
83
+
84
+ final_output_node = workflow_raw_data["nodes"][2]
85
+ assert not DeepDiff(
86
+ {
87
+ "id": "50e3b446-afcd-4a5d-8c6f-5f05eaf2200e",
88
+ "type": "TERMINAL",
89
+ "data": {
90
+ "label": "Final Output",
91
+ "name": "output",
92
+ "target_handle_id": "8bd9f4f3-9f66-4d95-8e84-529b0002c531",
93
+ "output_id": "2b6389d0-266a-4be4-843e-4e543dd3d727",
94
+ "output_type": "STRING",
95
+ "node_input_id": "7a9f2d3a-0b23-4bd4-b567-e9493135b727",
96
+ },
97
+ "inputs": [
98
+ {
99
+ "id": "7a9f2d3a-0b23-4bd4-b567-e9493135b727",
100
+ "key": "node_input",
101
+ "value": {
102
+ "rules": [
103
+ {
104
+ "type": "NODE_OUTPUT",
105
+ "data": {
106
+ "node_id": "c2ed23f7-f6cb-4a56-a91c-2e5f9d8fda7f",
107
+ "output_id": "0a9c7a80-fc89-4a71-aac0-66489e4ddb85",
108
+ },
109
+ }
110
+ ],
111
+ "combinator": "OR",
112
+ },
113
+ }
114
+ ],
115
+ "display_data": {"position": {"x": 0.0, "y": 0.0}},
116
+ "definition": {
117
+ "name": "FinalOutputNode",
118
+ "module": ["vellum", "workflows", "nodes", "displayable", "final_output_node", "node"],
119
+ "bases": [
120
+ {"name": "BaseNode", "module": ["vellum", "workflows", "nodes", "bases", "base"], "bases": []}
121
+ ],
122
+ },
123
+ },
124
+ final_output_node,
125
+ ignore_order=True,
126
+ )
127
+
128
+ # AND each edge should be serialized correctly
129
+ serialized_edges = workflow_raw_data["edges"]
130
+ assert not DeepDiff(
131
+ [
132
+ {
133
+ "id": "445dd2de-82b2-482b-89f6-5f49d8eb21a9",
134
+ "source_node_id": "f1e4678f-c470-400b-a40e-c8922cc99a86",
135
+ "source_handle_id": "40201804-8beb-43ad-8873-a027759512f1",
136
+ "target_node_id": "c2ed23f7-f6cb-4a56-a91c-2e5f9d8fda7f",
137
+ "target_handle_id": "b7bfb298-959a-4d2b-8b85-bbd0d2522703",
138
+ "type": "DEFAULT",
139
+ },
140
+ {
141
+ "id": "b741c861-cf67-4649-b9ef-b43a4add72b1",
142
+ "source_node_id": "c2ed23f7-f6cb-4a56-a91c-2e5f9d8fda7f",
143
+ "source_handle_id": "89dccfa5-cc1a-4612-bd87-86cb444f6dd4",
144
+ "target_node_id": "50e3b446-afcd-4a5d-8c6f-5f05eaf2200e",
145
+ "target_handle_id": "8bd9f4f3-9f66-4d95-8e84-529b0002c531",
146
+ "type": "DEFAULT",
147
+ },
148
+ ],
149
+ serialized_edges,
150
+ ignore_order=True,
151
+ )
152
+
153
+ # AND the display data should be what we expect
154
+ display_data = workflow_raw_data["display_data"]
155
+ assert display_data == {
156
+ "viewport": {
157
+ "x": 0.0,
158
+ "y": 0.0,
159
+ "zoom": 1.0,
160
+ }
161
+ }
162
+
163
+ # AND the definition should be what we expect
164
+ definition = workflow_raw_data["definition"]
165
+ assert definition == {
166
+ "name": "BasicGenericNodeWorkflow",
167
+ "module": ["tests", "workflows", "basic_generic_node", "workflow"],
168
+ }