vellum-ai 0.10.8__py3-none-any.whl → 0.10.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,203 @@
1
+ from unittest import mock
2
+
3
+ from deepdiff import DeepDiff
4
+
5
+ from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeVellumDisplay
6
+ from vellum_ee.workflows.display.workflows import VellumWorkflowDisplay
7
+ from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
8
+
9
+ from tests.workflows.basic_error_node.workflow import BasicErrorNodeWorkflow
10
+
11
+
12
def test_serialize_workflow():
    """Serialize ``BasicErrorNodeWorkflow`` with the Vellum display and check the payload.

    ``BaseNodeVellumDisplay.serialize`` is patched to return ``{"type": "MOCKED"}``
    because plain ``BaseNode`` serialization is not supported yet (see TODO below).
    Nodes rendered through that display therefore only show up as placeholders.
    The entrypoint, error, and terminal nodes are compared field by field.
    """
    # GIVEN a Workflow with an error node
    # WHEN we serialize it
    workflow_display = get_workflow_display(
        base_display_class=VellumWorkflowDisplay, workflow_class=BasicErrorNodeWorkflow
    )

    # TODO: Support serialization of BaseNode
    # https://app.shortcut.com/vellum/story/4871/support-serialization-of-base-node
    with mock.patch.object(BaseNodeVellumDisplay, "serialize") as mocked_serialize:
        mocked_serialize.return_value = {"type": "MOCKED"}
        serialized_workflow: dict = workflow_display.serialize()

    # THEN we should get a serialized representation of the Workflow
    assert serialized_workflow.keys() == {
        "workflow_raw_data",
        "input_variables",
        "output_variables",
    }

    # AND its input variables should be what we expect
    input_variables = serialized_workflow["input_variables"]
    assert len(input_variables) == 1
    assert not DeepDiff(
        [
            {
                "id": "5d9edd44-b35b-4bad-ad51-ccdfe8185ff5",
                "key": "threshold",
                "type": "NUMBER",
                "default": None,
                "required": True,
                "extensions": {"color": None},
            }
        ],
        input_variables,
        ignore_order=True,
    )

    # AND its output variables should be what we expect
    output_variables = serialized_workflow["output_variables"]
    assert len(output_variables) == 1
    assert not DeepDiff(
        [
            {
                "id": "04c5c6be-f5e1-41b8-b668-39e179790d9e",
                "key": "final_value",
                "type": "NUMBER",
            }
        ],
        output_variables,
        ignore_order=True,
    )

    # AND its raw data should be what we expect
    workflow_raw_data = serialized_workflow["workflow_raw_data"]
    assert workflow_raw_data.keys() == {"edges", "nodes", "display_data", "definition"}
    assert len(workflow_raw_data["edges"]) == 4
    assert len(workflow_raw_data["nodes"]) == 5

    nodes = workflow_raw_data["nodes"]

    # AND each node should be serialized correctly
    # The entrypoint is expected first in the node list.
    entrypoint_node = nodes[0]
    assert entrypoint_node == {
        "id": "10e90662-e998-421d-a5c9-ec16e37a8de1",
        "type": "ENTRYPOINT",
        "inputs": [],
        "data": {
            "label": "Entrypoint Node",
            "source_handle_id": "7d86498b-84ed-4feb-8e62-2188058c2c4e",
        },
        "display_data": {"position": {"x": 0.0, "y": 0.0}},
        "definition": {
            "name": "BaseNode",
            "module": ["vellum", "workflows", "nodes", "bases", "base"],
            "bases": [],
        },
    }

    # Locate the error node by label; mocked placeholders have no "data" key,
    # hence the .get() chain. Fail loudly if it is missing rather than letting
    # a None sentinel flow into the comparisons below.
    error_index = next(
        (
            index
            for index, node in enumerate(nodes)
            if node.get("data", {}).get("label") == "Fail Node"
        ),
        None,
    )
    assert error_index is not None, "expected a node labelled 'Fail Node' in the serialized workflow"
    error_node = nodes[error_index]
    assert not DeepDiff(
        {
            "id": "5cf9c5e3-0eae-4daf-8d73-8b9536258eb9",
            "type": "ERROR",
            "inputs": [],
            "data": {
                "name": "error-node",
                "label": "Fail Node",
                "source_handle_id": "ca17d318-a0f5-4f7c-be6c-59c9dc1dd7ed",
                "target_handle_id": "70c19f1c-309c-4a5d-ba65-664c0bb2fedf",
                "error_source_input_id": "None",
                "error_output_id": "None",
            },
            "display_data": {"position": {"x": 0.0, "y": 0.0}},
            "definition": {
                "name": "FailNode",
                "module": ["tests", "workflows", "basic_error_node", "workflow"],
                "bases": [
                    {
                        "name": "ErrorNode",
                        "module": [
                            "vellum",
                            "workflows",
                            "nodes",
                            "core",
                            "error_node",
                            "node",
                        ],
                    }
                ],
            },
        },
        error_node,
        ignore_order=True,
    )

    # Every node other than the entrypoint (first), the terminal (last) and the
    # error node is expected to have gone through the patched serializer.
    excluded_indices = {0, error_index, len(nodes) - 1}
    mocked_base_nodes = [node for i, node in enumerate(nodes) if i not in excluded_indices]

    assert not DeepDiff(
        [
            {
                "type": "MOCKED",
            },
            {
                "type": "MOCKED",
            },
        ],
        mocked_base_nodes,
    )

    # The terminal node is expected last in the node list.
    terminal_node = nodes[-1]
    assert not DeepDiff(
        {
            "id": "e5fff999-80c7-4cbc-9d99-06c653f3ec77",
            "type": "TERMINAL",
            "data": {
                "label": "Final Output",
                "name": "final_value",
                "target_handle_id": "b070e9bc-e9b7-46d3-8f5b-0b646bd25cf0",
                "output_id": "04c5c6be-f5e1-41b8-b668-39e179790d9e",
                "output_type": "NUMBER",
                "node_input_id": "39ff42c9-eae8-432e-ad41-e208fba77027",
            },
            "inputs": [
                {
                    "id": "39ff42c9-eae8-432e-ad41-e208fba77027",
                    "key": "node_input",
                    "value": {
                        "rules": [
                            {
                                "type": "NODE_OUTPUT",
                                "data": {
                                    "node_id": "1eee9b4e-531f-45f2-a4b9-42207fac2c33",
                                    "output_id": "c6b017a4-25e9-4296-8d81-6aa4b3dad171",
                                },
                            }
                        ],
                        "combinator": "OR",
                    },
                }
            ],
            "display_data": {"position": {"x": 0.0, "y": 0.0}},
            "definition": {
                "name": "FinalOutputNode",
                "module": [
                    "vellum",
                    "workflows",
                    "nodes",
                    "displayable",
                    "final_output_node",
                    "node",
                ],
                "bases": [
                    {
                        "name": "BaseNode",
                        "module": ["vellum", "workflows", "nodes", "bases", "base"],
                        "bases": [],
                    }
                ],
            },
        },
        terminal_node,
    )
@@ -1,5 +0,0 @@
1
- from vellum.workflows.events.workflow import WorkflowEvent
2
-
3
-
4
def is_terminal_event(event: WorkflowEvent) -> bool:
    """Return True when *event* is a fulfilled, rejected or paused workflow execution event.

    Only ``event.name`` is inspected; it is matched against the three
    ``workflow.execution.*`` names listed below.
    """
    terminal_names = frozenset(
        (
            "workflow.execution.fulfilled",
            "workflow.execution.rejected",
            "workflow.execution.paused",
        )
    )
    return event.name in terminal_names