vellum-ai 0.12.14__py3-none-any.whl → 0.12.16__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (59) hide show
  1. vellum/__init__.py +6 -0
  2. vellum/client/__init__.py +2 -6
  3. vellum/client/core/client_wrapper.py +1 -1
  4. vellum/client/environment.py +3 -3
  5. vellum/client/resources/ad_hoc/client.py +2 -6
  6. vellum/client/resources/container_images/client.py +0 -8
  7. vellum/client/resources/metric_definitions/client.py +2 -6
  8. vellum/client/resources/workflows/client.py +8 -8
  9. vellum/client/types/__init__.py +6 -0
  10. vellum/client/types/audio_prompt_block.py +29 -0
  11. vellum/client/types/function_call_prompt_block.py +30 -0
  12. vellum/client/types/image_prompt_block.py +29 -0
  13. vellum/client/types/prompt_block.py +12 -1
  14. vellum/client/types/workflow_push_response.py +1 -0
  15. vellum/prompts/blocks/compilation.py +43 -0
  16. vellum/types/audio_prompt_block.py +3 -0
  17. vellum/types/function_call_prompt_block.py +3 -0
  18. vellum/types/image_prompt_block.py +3 -0
  19. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +10 -2
  20. vellum/workflows/nodes/core/inline_subworkflow_node/tests/test_node.py +16 -0
  21. {vellum_ai-0.12.14.dist-info → vellum_ai-0.12.16.dist-info}/METADATA +11 -9
  22. {vellum_ai-0.12.14.dist-info → vellum_ai-0.12.16.dist-info}/RECORD +59 -48
  23. vellum_cli/__init__.py +14 -0
  24. vellum_cli/config.py +4 -0
  25. vellum_cli/pull.py +20 -5
  26. vellum_cli/push.py +33 -4
  27. vellum_cli/tests/test_pull.py +19 -1
  28. vellum_cli/tests/test_push.py +63 -0
  29. vellum_ee/workflows/display/nodes/vellum/__init__.py +2 -0
  30. vellum_ee/workflows/display/nodes/vellum/api_node.py +3 -3
  31. vellum_ee/workflows/display/nodes/vellum/base_node.py +35 -0
  32. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +2 -2
  33. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +2 -2
  34. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +20 -6
  35. vellum_ee/workflows/display/nodes/vellum/map_node.py +1 -0
  36. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +2 -2
  37. vellum_ee/workflows/display/nodes/vellum/search_node.py +4 -2
  38. vellum_ee/workflows/display/nodes/vellum/templating_node.py +1 -1
  39. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +3 -3
  40. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/__init__.py +0 -0
  41. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +28 -0
  42. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_trigger_serialization.py +123 -0
  43. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +6 -46
  44. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +3 -25
  45. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +168 -0
  46. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +18 -10
  47. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +18 -10
  48. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +3 -25
  49. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +2 -8
  50. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +13 -27
  51. vellum_ee/workflows/display/types.py +5 -1
  52. vellum_ee/workflows/display/utils/vellum.py +3 -3
  53. vellum_ee/workflows/display/vellum.py +4 -0
  54. vellum_ee/workflows/display/workflows/base_workflow_display.py +44 -16
  55. vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +3 -0
  56. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +7 -8
  57. {vellum_ai-0.12.14.dist-info → vellum_ai-0.12.16.dist-info}/LICENSE +0 -0
  58. {vellum_ai-0.12.14.dist-info → vellum_ai-0.12.16.dist-info}/WHEEL +0 -0
  59. {vellum_ai-0.12.14.dist-info → vellum_ai-0.12.16.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,35 @@
1
+ from typing import Any, Generic, TypeVar
2
+
3
+ from vellum.workflows.nodes.bases.base import BaseNode
4
+ from vellum.workflows.types.core import JsonObject
5
+ from vellum.workflows.utils.uuids import uuid4_from_hash
6
+ from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeVellumDisplay
7
+ from vellum_ee.workflows.display.types import WorkflowDisplayContext
8
+ from vellum_ee.workflows.display.vellum import GenericNodeDisplayData
9
+
10
+ _BaseNodeType = TypeVar("_BaseNodeType", bound=BaseNode)
11
+
12
+
13
+ class BaseNodeDisplay(BaseNodeVellumDisplay[_BaseNodeType], Generic[_BaseNodeType]):
14
+ def serialize(self, display_context: WorkflowDisplayContext, **kwargs: Any) -> JsonObject:
15
+ node = self._node
16
+ node_id = self.node_id
17
+
18
+ return {
19
+ "id": str(node_id),
20
+ "label": node.__qualname__,
21
+ "type": "GENERIC",
22
+ "display_data": self.get_generic_node_display_data().dict(),
23
+ "definition": self.get_definition().dict(),
24
+ "trigger": {
25
+ "id": str(uuid4_from_hash(f"{node_id}|trigger")),
26
+ "merge_behavior": node.Trigger.merge_behavior.value,
27
+ },
28
+ "ports": [],
29
+ "adornments": None,
30
+ "attributes": [],
31
+ }
32
+
33
+ def get_generic_node_display_data(self) -> GenericNodeDisplayData:
34
+ explicit_value = self._get_explicit_node_display_attr("display_data", GenericNodeDisplayData)
35
+ return explicit_value if explicit_value else GenericNodeDisplayData()
@@ -70,8 +70,8 @@ class BaseCodeExecutionNodeDisplay(BaseNodeVellumDisplay[_CodeExecutionNodeType]
70
70
 
71
71
  packages = raise_if_descriptor(node.packages)
72
72
 
73
- _, output_display = display_context.node_output_displays[node.Outputs.result]
74
- _, log_output_display = display_context.node_output_displays[node.Outputs.log]
73
+ _, output_display = display_context.global_node_output_displays[node.Outputs.result]
74
+ _, log_output_display = display_context.global_node_output_displays[node.Outputs.log]
75
75
 
76
76
  output_type = primitive_type_to_vellum_variable_type(node.get_output_type())
77
77
 
@@ -29,8 +29,8 @@ class BaseInlinePromptNodeDisplay(BaseNodeVellumDisplay[_InlinePromptNodeType],
29
29
  node_inputs, prompt_inputs = self._generate_node_and_prompt_inputs(node_id, node, display_context)
30
30
  input_variable_id_by_name = {prompt_input.key: prompt_input.id for prompt_input in prompt_inputs}
31
31
 
32
- _, output_display = display_context.node_output_displays[node.Outputs.text]
33
- _, array_display = display_context.node_output_displays[node.Outputs.results]
32
+ _, output_display = display_context.global_node_output_displays[node.Outputs.text]
33
+ _, array_display = display_context.global_node_output_displays[node.Outputs.results]
34
34
  node_blocks = raise_if_descriptor(node.blocks)
35
35
 
36
36
  return {
@@ -2,6 +2,7 @@ from uuid import UUID
2
2
  from typing import ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, cast
3
3
 
4
4
  from vellum import VellumVariable
5
+ from vellum.workflows.inputs.base import BaseInputs
5
6
  from vellum.workflows.nodes import InlineSubworkflowNode
6
7
  from vellum.workflows.types.core import JsonObject
7
8
  from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeVellumDisplay
@@ -60,12 +61,25 @@ class BaseInlineSubworkflowNodeDisplay(
60
61
  node: Type[InlineSubworkflowNode],
61
62
  display_context: WorkflowDisplayContext,
62
63
  ) -> Tuple[List[NodeInput], List[VellumVariable]]:
64
+ subworkflow = raise_if_descriptor(node.subworkflow)
65
+ subworkflow_inputs_class = subworkflow.get_inputs_class()
63
66
  subworkflow_inputs = raise_if_descriptor(node.subworkflow_inputs)
64
- subworkflow_entries = (
65
- [(variable_name, variable_value) for variable_name, variable_value in subworkflow_inputs.items()]
66
- if isinstance(subworkflow_inputs, dict)
67
- else [(variable_ref.name, variable_value) for variable_ref, variable_value in subworkflow_inputs]
68
- )
67
+
68
+ if isinstance(subworkflow_inputs, BaseInputs):
69
+ subworkflow_entries = [
70
+ (variable_ref.name, variable_value) for variable_ref, variable_value in subworkflow_inputs
71
+ ]
72
+ elif isinstance(subworkflow_inputs, dict):
73
+ subworkflow_entries = [
74
+ (variable_name, variable_value) for variable_name, variable_value in subworkflow_inputs.items()
75
+ ]
76
+ else:
77
+ subworkflow_entries = [
78
+ (descriptor.name, getattr(subworkflow_inputs_class, descriptor.name))
79
+ for descriptor in subworkflow_inputs_class
80
+ if hasattr(subworkflow_inputs_class, descriptor.name)
81
+ ]
82
+
69
83
  node_inputs = [
70
84
  create_node_input(
71
85
  node_id=node_id,
@@ -83,7 +97,7 @@ class BaseInlineSubworkflowNodeDisplay(
83
97
  key=descriptor.name,
84
98
  type=infer_vellum_variable_type(descriptor),
85
99
  )
86
- for descriptor in raise_if_descriptor(node.subworkflow).get_inputs_class()
100
+ for descriptor in subworkflow_inputs_class
87
101
  ]
88
102
 
89
103
  return node_inputs, workflow_inputs
@@ -33,6 +33,7 @@ class BaseMapNodeDisplay(BaseNodeVellumDisplay[_MapNodeType], Generic[_MapNodeTy
33
33
  subworkflow_display = get_workflow_display(
34
34
  base_display_class=display_context.workflow_display_class,
35
35
  workflow_class=subworkflow,
36
+ parent_display_context=display_context,
36
37
  )
37
38
  serialized_subworkflow = subworkflow_display.serialize()
38
39
 
@@ -38,8 +38,8 @@ class BasePromptDeploymentNodeDisplay(
38
38
  for variable_name, variable_value in prompt_inputs.items()
39
39
  ]
40
40
 
41
- _, output_display = display_context.node_output_displays[cast(OutputReference, node.Outputs.text)]
42
- _, array_display = display_context.node_output_displays[cast(OutputReference, node.Outputs.results)]
41
+ _, output_display = display_context.global_node_output_displays[cast(OutputReference, node.Outputs.text)]
42
+ _, array_display = display_context.global_node_output_displays[cast(OutputReference, node.Outputs.results)]
43
43
 
44
44
  # TODO: Pass through the name instead of retrieving the ID
45
45
  # https://app.shortcut.com/vellum/story/4702
@@ -39,8 +39,10 @@ class BaseSearchNodeDisplay(BaseNodeVellumDisplay[_SearchNodeType], Generic[_Sea
39
39
  node_id = self.node_id
40
40
  node_inputs = self._generate_search_node_inputs(node_id, node, display_context)
41
41
 
42
- _, results_output_display = display_context.node_output_displays[cast(OutputReference, node.Outputs.results)]
43
- _, text_output_display = display_context.node_output_displays[cast(OutputReference, node.Outputs.text)]
42
+ _, results_output_display = display_context.global_node_output_displays[
43
+ cast(OutputReference, node.Outputs.results)
44
+ ]
45
+ _, text_output_display = display_context.global_node_output_displays[cast(OutputReference, node.Outputs.text)]
44
46
 
45
47
  return {
46
48
  "id": str(node_id),
@@ -56,7 +56,7 @@ class BaseTemplatingNodeDisplay(BaseNodeVellumDisplay[_TemplatingNodeType], Gene
56
56
  # Misc type ignore is due to `node.Outputs` being generic
57
57
  # https://app.shortcut.com/vellum/story/4784
58
58
  output_descriptor = node.Outputs.result # type: ignore [misc]
59
- _, output_display = display_context.node_output_displays[output_descriptor]
59
+ _, output_display = display_context.global_node_output_displays[output_descriptor]
60
60
  inferred_output_type = primitive_type_to_vellum_variable_type(output_descriptor)
61
61
 
62
62
  return {
@@ -91,18 +91,18 @@ def test_create_node_input_value_pointer_rules(
91
91
  entrypoint_node_source_handle_id=uuid4(),
92
92
  entrypoint_node_display=NodeDisplayData(),
93
93
  ),
94
- workflow_input_displays={
94
+ global_workflow_input_displays={
95
95
  cast(WorkflowInputReference, Inputs.example_workflow_input): WorkflowInputsVellumDisplayOverrides(
96
96
  id=UUID("a154c29d-fac0-4cd0-ba88-bc52034f5470"),
97
97
  ),
98
98
  },
99
- node_output_displays={
99
+ global_node_output_displays={
100
100
  cast(OutputReference, MyNodeA.Outputs.output): (
101
101
  MyNodeA,
102
102
  NodeOutputDisplay(id=UUID("4b16a629-11a1-4b3f-a965-a57b872d13b8"), name="output"),
103
103
  ),
104
104
  },
105
- node_displays={
105
+ global_node_displays={
106
106
  MyNodeA: MyNodeADisplay(),
107
107
  },
108
108
  ),
@@ -0,0 +1,28 @@
1
+ import pytest
2
+ from uuid import uuid4
3
+
4
+ from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
5
+ from vellum_ee.workflows.display.nodes.vellum.base_node import BaseNodeDisplay
6
+ from vellum_ee.workflows.display.types import WorkflowDisplayContext
7
+ from vellum_ee.workflows.display.vellum import NodeDisplayData, WorkflowMetaVellumDisplay
8
+ from vellum_ee.workflows.display.workflows.vellum_workflow_display import VellumWorkflowDisplay
9
+
10
+
11
+ @pytest.fixture()
12
+ def serialize_node():
13
+ def _serialize_node(node_class) -> dict:
14
+ node_display_class = get_node_display_class(BaseNodeDisplay, node_class)
15
+ node_display = node_display_class()
16
+
17
+ context: WorkflowDisplayContext = WorkflowDisplayContext(
18
+ workflow_display_class=VellumWorkflowDisplay,
19
+ workflow_display=WorkflowMetaVellumDisplay(
20
+ entrypoint_node_id=uuid4(),
21
+ entrypoint_node_source_handle_id=uuid4(),
22
+ entrypoint_node_display=NodeDisplayData(),
23
+ ),
24
+ node_displays={node_class: node_display},
25
+ )
26
+ return node_display.serialize(context)
27
+
28
+ return _serialize_node
@@ -0,0 +1,123 @@
1
+ from deepdiff import DeepDiff
2
+
3
+ from vellum.workflows.inputs.base import BaseInputs
4
+ from vellum.workflows.nodes.bases.base import BaseNode
5
+ from vellum.workflows.types.core import MergeBehavior
6
+
7
+
8
+ class Inputs(BaseInputs):
9
+ input: str
10
+
11
+
12
+ class BasicGenericNode(BaseNode):
13
+ class Outputs(BaseNode.Outputs):
14
+ output = Inputs.input
15
+
16
+
17
+ class AwaitAnyGenericNode(BaseNode):
18
+ class Outputs(BaseNode.Outputs):
19
+ output = Inputs.input
20
+
21
+ class Trigger(BaseNode.Trigger):
22
+ merge_behavior = MergeBehavior.AWAIT_ANY
23
+
24
+
25
+ class AwaitAllGenericNode(BaseNode):
26
+ class Outputs(BaseNode.Outputs):
27
+ output = Inputs.input
28
+
29
+ class Trigger(BaseNode.Trigger):
30
+ merge_behavior = MergeBehavior.AWAIT_ALL
31
+
32
+
33
+ def test_serialize_node__basic(serialize_node):
34
+ serialized_node = serialize_node(BasicGenericNode)
35
+ assert not DeepDiff(
36
+ {
37
+ "id": "c2ed23f7-f6cb-4a56-a91c-2e5f9d8fda7f",
38
+ "label": "BasicGenericNode",
39
+ "type": "GENERIC",
40
+ "display_data": {"position": {"x": 0.0, "y": 0.0}},
41
+ "definition": {
42
+ "name": "BasicGenericNode",
43
+ "module": [
44
+ "vellum_ee",
45
+ "workflows",
46
+ "display",
47
+ "tests",
48
+ "workflow_serialization",
49
+ "generic_nodes",
50
+ "test_trigger_serialization",
51
+ ],
52
+ "bases": [{"name": "BaseNode", "module": ["vellum", "workflows", "nodes", "bases", "base"]}],
53
+ },
54
+ "trigger": {"id": "9d3a1b3d-4a38-4f2e-bbf1-dd8be152bce8", "merge_behavior": "AWAIT_ANY"},
55
+ "ports": [],
56
+ "adornments": None,
57
+ "attributes": [],
58
+ },
59
+ serialized_node,
60
+ ignore_order=True,
61
+ )
62
+
63
+
64
+ def test_serialize_node__await_any(serialize_node):
65
+ serialized_node = serialize_node(AwaitAnyGenericNode)
66
+ assert not DeepDiff(
67
+ {
68
+ "id": "0ba67f76-aaff-4bd4-a20f-73a32ef5810d",
69
+ "label": "AwaitAnyGenericNode",
70
+ "type": "GENERIC",
71
+ "display_data": {"position": {"x": 0.0, "y": 0.0}},
72
+ "definition": {
73
+ "name": "AwaitAnyGenericNode",
74
+ "module": [
75
+ "vellum_ee",
76
+ "workflows",
77
+ "display",
78
+ "tests",
79
+ "workflow_serialization",
80
+ "generic_nodes",
81
+ "test_trigger_serialization",
82
+ ],
83
+ "bases": [{"name": "BaseNode", "module": ["vellum", "workflows", "nodes", "bases", "base"]}],
84
+ },
85
+ "trigger": {"id": "ffa72187-9a18-453f-ae55-b77aad332630", "merge_behavior": "AWAIT_ANY"},
86
+ "ports": [],
87
+ "adornments": None,
88
+ "attributes": [],
89
+ },
90
+ serialized_node,
91
+ ignore_order=True,
92
+ )
93
+
94
+
95
+ def test_serialize_node__await_all(serialize_node):
96
+ serialized_node = serialize_node(AwaitAllGenericNode)
97
+ assert not DeepDiff(
98
+ {
99
+ "id": "09d06cd3-06ea-40cc-afd8-17ad88542271",
100
+ "label": "AwaitAllGenericNode",
101
+ "type": "GENERIC",
102
+ "display_data": {"position": {"x": 0.0, "y": 0.0}},
103
+ "definition": {
104
+ "name": "AwaitAllGenericNode",
105
+ "module": [
106
+ "vellum_ee",
107
+ "workflows",
108
+ "display",
109
+ "tests",
110
+ "workflow_serialization",
111
+ "generic_nodes",
112
+ "test_trigger_serialization",
113
+ ],
114
+ "bases": [{"name": "BaseNode", "module": ["vellum", "workflows", "nodes", "bases", "base"]}],
115
+ },
116
+ "trigger": {"id": "62074276-c817-476d-b59d-da523ae3f218", "merge_behavior": "AWAIT_ALL"},
117
+ "ports": [],
118
+ "adornments": None,
119
+ "attributes": [],
120
+ },
121
+ serialized_node,
122
+ ignore_order=True,
123
+ )
@@ -1,5 +1,4 @@
1
1
  import pytest
2
- from unittest import mock
3
2
 
4
3
  from deepdiff import DeepDiff
5
4
 
@@ -21,7 +20,6 @@ from vellum.workflows.expressions.less_than import LessThanExpression
21
20
  from vellum.workflows.expressions.less_than_or_equal_to import LessThanOrEqualToExpression
22
21
  from vellum.workflows.expressions.not_between import NotBetweenExpression
23
22
  from vellum.workflows.expressions.not_in import NotInExpression
24
- from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeVellumDisplay
25
23
  from vellum_ee.workflows.display.workflows import VellumWorkflowDisplay
26
24
  from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
27
25
 
@@ -33,12 +31,7 @@ def test_serialize_workflow():
33
31
  # GIVEN a Workflow that uses a ConditionalNode
34
32
  # WHEN we serialize it
35
33
  workflow_display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=CategoryWorkflow)
36
-
37
- # TODO: Support serialization of BaseNode
38
- # https://app.shortcut.com/vellum/story/4871/support-serialization-of-base-node
39
- with mock.patch.object(BaseNodeVellumDisplay, "serialize") as mocked_serialize:
40
- mocked_serialize.return_value = {"type": "MOCKED"}
41
- serialized_workflow: dict = workflow_display.serialize()
34
+ serialized_workflow: dict = workflow_display.serialize()
42
35
 
43
36
  # THEN we should get a serialized representation of the Workflow
44
37
  assert serialized_workflow.keys() == {
@@ -455,26 +448,8 @@ def test_serialize_workflow():
455
448
  ignore_order=True,
456
449
  )
457
450
 
458
- assert not DeepDiff(
459
- [
460
- {
461
- "type": "MOCKED",
462
- },
463
- {
464
- "type": "MOCKED",
465
- },
466
- {
467
- "type": "MOCKED",
468
- },
469
- {
470
- "type": "MOCKED",
471
- },
472
- {
473
- "type": "MOCKED",
474
- },
475
- ],
476
- workflow_raw_data["nodes"][2:7],
477
- )
451
+ passthrough_nodes = [node for node in workflow_raw_data["nodes"] if node["type"] == "GENERIC"]
452
+ assert len(passthrough_nodes) == 5
478
453
 
479
454
  assert not DeepDiff(
480
455
  [
@@ -917,12 +892,7 @@ def test_conditional_node_serialize_all_operators_with_lhs_and_rhs(descriptor, o
917
892
  workflow_cls = create_simple_workflow(descriptor)
918
893
 
919
894
  workflow_display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=workflow_cls)
920
-
921
- # TODO: Support serialization of BaseNode
922
- # https://app.shortcut.com/vellum/story/4871/support-serialization-of-base-node
923
- with mock.patch.object(BaseNodeVellumDisplay, "serialize") as mocked_serialize:
924
- mocked_serialize.return_value = {"type": "MOCKED"}
925
- serialized_workflow: dict = workflow_display.serialize()
895
+ serialized_workflow: dict = workflow_display.serialize()
926
896
 
927
897
  # THEN we should get a serialized representation of the Workflow
928
898
  assert serialized_workflow.keys() == {
@@ -1041,12 +1011,7 @@ def test_conditional_node_serialize_all_operators_with_expression(descriptor, op
1041
1011
  workflow_cls = create_simple_workflow(descriptor)
1042
1012
 
1043
1013
  workflow_display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=workflow_cls)
1044
-
1045
- # TODO: Support serialization of BaseNode
1046
- # https://app.shortcut.com/vellum/story/4871/support-serialization-of-base-node
1047
- with mock.patch.object(BaseNodeVellumDisplay, "serialize") as mocked_serialize:
1048
- mocked_serialize.return_value = {"type": "MOCKED"}
1049
- serialized_workflow: dict = workflow_display.serialize()
1014
+ serialized_workflow: dict = workflow_display.serialize()
1050
1015
 
1051
1016
  # THEN we should get a serialized representation of the Workflow
1052
1017
  assert serialized_workflow.keys() == {
@@ -1152,12 +1117,7 @@ def test_conditional_node_serialize_all_operators_with_value_and_start_and_end(d
1152
1117
  workflow_cls = create_simple_workflow(descriptor)
1153
1118
 
1154
1119
  workflow_display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=workflow_cls)
1155
-
1156
- # TODO: Support serialization of BaseNode
1157
- # https://app.shortcut.com/vellum/story/4871/support-serialization-of-base-node
1158
- with mock.patch.object(BaseNodeVellumDisplay, "serialize") as mocked_serialize:
1159
- mocked_serialize.return_value = {"type": "MOCKED"}
1160
- serialized_workflow: dict = workflow_display.serialize()
1120
+ serialized_workflow: dict = workflow_display.serialize()
1161
1121
 
1162
1122
  # THEN we should get a serialized representation of the Workflow
1163
1123
  assert serialized_workflow.keys() == {
@@ -1,8 +1,5 @@
1
- from unittest import mock
2
-
3
1
  from deepdiff import DeepDiff
4
2
 
5
- from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeVellumDisplay
6
3
  from vellum_ee.workflows.display.workflows import VellumWorkflowDisplay
7
4
  from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
8
5
 
@@ -16,11 +13,7 @@ def test_serialize_workflow():
16
13
  base_display_class=VellumWorkflowDisplay, workflow_class=BasicErrorNodeWorkflow
17
14
  )
18
15
 
19
- # TODO: Support serialization of BaseNode
20
- # https://app.shortcut.com/vellum/story/4871/support-serialization-of-base-node
21
- with mock.patch.object(BaseNodeVellumDisplay, "serialize") as mocked_serialize:
22
- mocked_serialize.return_value = {"type": "MOCKED"}
23
- serialized_workflow: dict = workflow_display.serialize()
16
+ serialized_workflow: dict = workflow_display.serialize()
24
17
 
25
18
  # THEN we should get a serialized representation of the Workflow
26
19
  assert serialized_workflow.keys() == {
@@ -129,23 +122,8 @@ def test_serialize_workflow():
129
122
  ignore_order=True,
130
123
  )
131
124
 
132
- mocked_base_nodes = [
133
- node
134
- for i, node in enumerate(workflow_raw_data["nodes"])
135
- if i != error_index and i != 0 and i != len(workflow_raw_data["nodes"]) - 1
136
- ]
137
-
138
- assert not DeepDiff(
139
- [
140
- {
141
- "type": "MOCKED",
142
- },
143
- {
144
- "type": "MOCKED",
145
- },
146
- ],
147
- mocked_base_nodes,
148
- )
125
+ passthrough_nodes = [node for node in workflow_raw_data["nodes"] if node["type"] == "GENERIC"]
126
+ assert len(passthrough_nodes) == 2
149
127
 
150
128
  terminal_node = workflow_raw_data["nodes"][-1]
151
129
  assert not DeepDiff(
@@ -0,0 +1,168 @@
1
+ from deepdiff import DeepDiff
2
+
3
+ from vellum_ee.workflows.display.workflows import VellumWorkflowDisplay
4
+ from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
5
+
6
+ from tests.workflows.basic_generic_node.workflow import BasicGenericNodeWorkflow
7
+
8
+
9
+ def test_serialize_workflow(vellum_client):
10
+ # GIVEN a Workflow that uses a generic node
11
+ # WHEN we serialize it
12
+ workflow_display = get_workflow_display(
13
+ base_display_class=VellumWorkflowDisplay, workflow_class=BasicGenericNodeWorkflow
14
+ )
15
+
16
+ serialized_workflow: dict = workflow_display.serialize()
17
+
18
+ # THEN we should get a serialized representation of the Workflow
19
+ assert serialized_workflow.keys() == {
20
+ "workflow_raw_data",
21
+ "input_variables",
22
+ "output_variables",
23
+ }
24
+
25
+ # AND its input variables should be what we expect
26
+ input_variables = serialized_workflow["input_variables"]
27
+ assert len(input_variables) == 1
28
+ assert not DeepDiff(
29
+ [
30
+ {
31
+ "id": "a07c2273-34a7-42b5-bcad-143b6127cc8a",
32
+ "key": "input",
33
+ "type": "STRING",
34
+ "default": None,
35
+ "required": True,
36
+ "extensions": {"color": None},
37
+ },
38
+ ],
39
+ input_variables,
40
+ ignore_order=True,
41
+ )
42
+
43
+ # AND its output variables should be what we expect
44
+ output_variables = serialized_workflow["output_variables"]
45
+ assert len(output_variables) == 1
46
+ assert not DeepDiff(
47
+ [
48
+ {"id": "2b6389d0-266a-4be4-843e-4e543dd3d727", "key": "output", "type": "STRING"},
49
+ ],
50
+ output_variables,
51
+ ignore_order=True,
52
+ )
53
+
54
+ # AND its raw data should be what we expect
55
+ workflow_raw_data = serialized_workflow["workflow_raw_data"]
56
+ assert workflow_raw_data.keys() == {"edges", "nodes", "display_data", "definition"}
57
+ assert len(workflow_raw_data["edges"]) == 2
58
+ assert len(workflow_raw_data["nodes"]) == 3
59
+
60
+ # AND each node should be serialized correctly
61
+ entrypoint_node = workflow_raw_data["nodes"][0]
62
+ assert entrypoint_node == {
63
+ "id": "f1e4678f-c470-400b-a40e-c8922cc99a86",
64
+ "type": "ENTRYPOINT",
65
+ "inputs": [],
66
+ "data": {"label": "Entrypoint Node", "source_handle_id": "40201804-8beb-43ad-8873-a027759512f1"},
67
+ "display_data": {"position": {"x": 0.0, "y": 0.0}},
68
+ "definition": {
69
+ "name": "BaseNode",
70
+ "module": [
71
+ "vellum",
72
+ "workflows",
73
+ "nodes",
74
+ "bases",
75
+ "base",
76
+ ],
77
+ "bases": [],
78
+ },
79
+ }
80
+
81
+ api_node = workflow_raw_data["nodes"][1]
82
+ assert api_node["id"] == "c2ed23f7-f6cb-4a56-a91c-2e5f9d8fda7f"
83
+
84
+ final_output_node = workflow_raw_data["nodes"][2]
85
+ assert not DeepDiff(
86
+ {
87
+ "id": "50e3b446-afcd-4a5d-8c6f-5f05eaf2200e",
88
+ "type": "TERMINAL",
89
+ "data": {
90
+ "label": "Final Output",
91
+ "name": "output",
92
+ "target_handle_id": "8bd9f4f3-9f66-4d95-8e84-529b0002c531",
93
+ "output_id": "2b6389d0-266a-4be4-843e-4e543dd3d727",
94
+ "output_type": "STRING",
95
+ "node_input_id": "7a9f2d3a-0b23-4bd4-b567-e9493135b727",
96
+ },
97
+ "inputs": [
98
+ {
99
+ "id": "7a9f2d3a-0b23-4bd4-b567-e9493135b727",
100
+ "key": "node_input",
101
+ "value": {
102
+ "rules": [
103
+ {
104
+ "type": "NODE_OUTPUT",
105
+ "data": {
106
+ "node_id": "c2ed23f7-f6cb-4a56-a91c-2e5f9d8fda7f",
107
+ "output_id": "0a9c7a80-fc89-4a71-aac0-66489e4ddb85",
108
+ },
109
+ }
110
+ ],
111
+ "combinator": "OR",
112
+ },
113
+ }
114
+ ],
115
+ "display_data": {"position": {"x": 0.0, "y": 0.0}},
116
+ "definition": {
117
+ "name": "FinalOutputNode",
118
+ "module": ["vellum", "workflows", "nodes", "displayable", "final_output_node", "node"],
119
+ "bases": [
120
+ {"name": "BaseNode", "module": ["vellum", "workflows", "nodes", "bases", "base"], "bases": []}
121
+ ],
122
+ },
123
+ },
124
+ final_output_node,
125
+ ignore_order=True,
126
+ )
127
+
128
+ # AND each edge should be serialized correctly
129
+ serialized_edges = workflow_raw_data["edges"]
130
+ assert not DeepDiff(
131
+ [
132
+ {
133
+ "id": "445dd2de-82b2-482b-89f6-5f49d8eb21a9",
134
+ "source_node_id": "f1e4678f-c470-400b-a40e-c8922cc99a86",
135
+ "source_handle_id": "40201804-8beb-43ad-8873-a027759512f1",
136
+ "target_node_id": "c2ed23f7-f6cb-4a56-a91c-2e5f9d8fda7f",
137
+ "target_handle_id": "b7bfb298-959a-4d2b-8b85-bbd0d2522703",
138
+ "type": "DEFAULT",
139
+ },
140
+ {
141
+ "id": "b741c861-cf67-4649-b9ef-b43a4add72b1",
142
+ "source_node_id": "c2ed23f7-f6cb-4a56-a91c-2e5f9d8fda7f",
143
+ "source_handle_id": "89dccfa5-cc1a-4612-bd87-86cb444f6dd4",
144
+ "target_node_id": "50e3b446-afcd-4a5d-8c6f-5f05eaf2200e",
145
+ "target_handle_id": "8bd9f4f3-9f66-4d95-8e84-529b0002c531",
146
+ "type": "DEFAULT",
147
+ },
148
+ ],
149
+ serialized_edges,
150
+ ignore_order=True,
151
+ )
152
+
153
+ # AND the display data should be what we expect
154
+ display_data = workflow_raw_data["display_data"]
155
+ assert display_data == {
156
+ "viewport": {
157
+ "x": 0.0,
158
+ "y": 0.0,
159
+ "zoom": 1.0,
160
+ }
161
+ }
162
+
163
+ # AND the definition should be what we expect
164
+ definition = workflow_raw_data["definition"]
165
+ assert definition == {
166
+ "name": "BasicGenericNodeWorkflow",
167
+ "module": ["tests", "workflows", "basic_generic_node", "workflow"],
168
+ }