vellum-ai 0.13.28__py3-none-any.whl → 0.14.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/prompts/blocks/compilation.py +23 -16
  3. vellum/prompts/blocks/tests/test_compilation.py +29 -0
  4. vellum/utils/templating/render.py +2 -0
  5. vellum/workflows/constants.py +8 -3
  6. vellum/workflows/descriptors/tests/test_utils.py +21 -0
  7. vellum/workflows/descriptors/utils.py +3 -3
  8. vellum/workflows/errors/types.py +4 -1
  9. vellum/workflows/expressions/coalesce_expression.py +2 -2
  10. vellum/workflows/expressions/contains.py +4 -3
  11. vellum/workflows/expressions/does_not_contain.py +2 -1
  12. vellum/workflows/expressions/is_nil.py +2 -2
  13. vellum/workflows/expressions/is_not_nil.py +2 -2
  14. vellum/workflows/expressions/is_not_undefined.py +2 -2
  15. vellum/workflows/expressions/is_undefined.py +2 -2
  16. vellum/workflows/nodes/bases/base.py +19 -3
  17. vellum/workflows/nodes/bases/tests/test_base_node.py +84 -0
  18. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +3 -3
  19. vellum/workflows/nodes/core/map_node/node.py +5 -0
  20. vellum/workflows/nodes/core/map_node/tests/test_node.py +22 -0
  21. vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +39 -1
  22. vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +68 -2
  23. vellum/workflows/nodes/displayable/code_execution_node/utils.py +30 -7
  24. vellum/workflows/nodes/utils.py +9 -1
  25. vellum/workflows/outputs/base.py +21 -19
  26. vellum/workflows/references/external_input.py +2 -2
  27. vellum/workflows/references/lazy.py +2 -2
  28. vellum/workflows/references/output.py +7 -7
  29. vellum/workflows/runner/runner.py +20 -15
  30. vellum/workflows/state/base.py +23 -3
  31. vellum/workflows/state/tests/test_state.py +7 -11
  32. vellum/workflows/workflows/base.py +20 -0
  33. vellum/workflows/workflows/tests/__init__.py +0 -0
  34. vellum/workflows/workflows/tests/test_base_workflow.py +80 -0
  35. {vellum_ai-0.13.28.dist-info → vellum_ai-0.14.1.dist-info}/METADATA +1 -1
  36. {vellum_ai-0.13.28.dist-info → vellum_ai-0.14.1.dist-info}/RECORD +67 -62
  37. vellum_ee/workflows/display/base.py +14 -0
  38. vellum_ee/workflows/display/nodes/base_node_display.py +13 -24
  39. vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_node.py +52 -0
  40. vellum_ee/workflows/display/tests/test_vellum_workflow_display.py +1 -0
  41. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +4 -1
  42. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +1 -0
  43. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +3 -0
  44. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +4 -0
  45. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +243 -0
  46. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +1 -0
  47. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +1 -0
  48. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +1 -1
  49. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -0
  50. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +1 -0
  51. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +1 -0
  52. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -0
  53. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +1 -1
  54. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +1 -0
  55. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -0
  56. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +1 -0
  57. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +1 -0
  58. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +1 -0
  59. vellum_ee/workflows/display/types.py +5 -1
  60. vellum_ee/workflows/display/utils/expressions.py +26 -0
  61. vellum_ee/workflows/display/utils/vellum.py +5 -0
  62. vellum_ee/workflows/display/vellum.py +14 -0
  63. vellum_ee/workflows/display/workflows/base_workflow_display.py +30 -1
  64. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +41 -0
  65. {vellum_ai-0.13.28.dist-info → vellum_ai-0.14.1.dist-info}/LICENSE +0 -0
  66. {vellum_ai-0.13.28.dist-info → vellum_ai-0.14.1.dist-info}/WHEEL +0 -0
  67. {vellum_ai-0.13.28.dist-info → vellum_ai-0.14.1.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,243 @@
1
+ from deepdiff import DeepDiff
2
+
3
+ from vellum_ee.workflows.display.workflows import VellumWorkflowDisplay
4
+ from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
5
+
6
+ from tests.workflows.basic_default_state.workflow import BasicDefaultStateWorkflow
7
+
8
+
9
def test_serialize_workflow():
    """Serializing ``BasicDefaultStateWorkflow`` yields the expected variables, nodes, edges and metadata.

    The workflow exposes one input and one state value (both keyed ``example`` and both with defaults),
    one intermediate node fed by the entrypoint, and two terminal nodes that surface the input and the
    state value respectively.
    """
    # Ids that appear in more than one section of the serialized payload.
    entrypoint_id = "32684932-7c7c-4b1c-aed2-553de29bf3f7"
    entrypoint_handle_id = "e4136ee4-a51a-4ca3-9a3a-aa96f5de2347"
    upstream_node_id = "1381c078-efa2-4255-89a1-7b4cb742c7fc"
    upstream_source_handle_id = "1e739e86-a285-4438-9725-a152c15a63e3"
    input_terminal_id = "ca8bb585-c9a8-4bf7-bf9d-534b600fe23b"
    input_terminal_handle_id = "8a4a7efd-0e18-43ed-ba32-803e22e3ba0a"
    input_output_id = "6e7eeaa5-9559-4ae3-8606-e52ead5805a5"
    state_terminal_id = "27fdaa45-b8ce-464d-be50-cf71cc56bc10"
    state_terminal_handle_id = "e7a09eb2-c9fb-4d57-b436-9cd9384c8960"
    state_output_id = "e3ae0fe3-7590-4eac-b808-45901d82f2ba"

    # GIVEN a Workflow that has a simple state definition
    # WHEN we serialize it
    display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=BasicDefaultStateWorkflow)
    payload: dict = display.serialize()

    # THEN the payload exposes exactly these top-level sections
    assert payload.keys() == {"workflow_raw_data", "input_variables", "state_variables", "output_variables"}

    # AND its input variables should be what we expect
    expected_inputs = [
        {
            "id": "0bbe085c-31d3-4b48-b74d-d501f592da90",
            "key": "example",
            "type": "STRING",
            "default": {"type": "STRING", "value": "hello"},
            "required": True,
            "extensions": {"color": None},
        },
    ]
    actual_inputs = payload["input_variables"]
    assert len(actual_inputs) == 1
    assert not DeepDiff(expected_inputs, actual_inputs, ignore_order=True)

    # AND its state variables should be what we expect
    expected_state = [
        {
            "id": "812ec99b-1859-4361-b795-228628657bac",
            "key": "example",
            "type": "NUMBER",
            "default": {"type": "NUMBER", "value": 5.0},
            "required": True,
            "extensions": {"color": None},
        },
    ]
    actual_state = payload["state_variables"]
    assert len(actual_state) == 1
    assert not DeepDiff(expected_state, actual_state, ignore_order=True)

    # AND its output variables should be what we expect
    expected_outputs = [
        {"id": input_output_id, "key": "example_input", "type": "STRING"},
        {"id": state_output_id, "key": "example_state", "type": "NUMBER"},
    ]
    actual_outputs = payload["output_variables"]
    assert len(actual_outputs) == 2
    assert not DeepDiff(expected_outputs, actual_outputs, ignore_order=True)

    # AND its raw data should be what we expect
    raw = payload["workflow_raw_data"]
    assert raw.keys() == {"edges", "nodes", "display_data", "definition"}
    nodes = raw["nodes"]
    edges = raw["edges"]
    assert len(edges) == 3
    assert len(nodes) == 4

    # AND each node should be serialized correctly, with the entrypoint listed first
    assert nodes[0] == {
        "id": entrypoint_id,
        "type": "ENTRYPOINT",
        "inputs": [],
        "data": {"label": "Entrypoint Node", "source_handle_id": entrypoint_handle_id},
        "base": None,
        "definition": None,
        "display_data": {"position": {"x": 0.0, "y": 0.0}},
    }

    def find_terminal(output_name: str) -> dict:
        # Terminal nodes are told apart by the workflow output they feed.
        return next(n for n in nodes if n["type"] == "TERMINAL" and n["data"]["name"] == output_name)

    def expected_terminal(
        node_id: str,
        output_name: str,
        target_handle_id: str,
        output_id: str,
        output_type: str,
        node_input_id: str,
        upstream_output_id: str,
    ) -> dict:
        # Both terminal nodes share this shape and differ only in the ids/types passed in.
        return {
            "id": node_id,
            "type": "TERMINAL",
            "data": {
                "label": "Final Output",
                "name": output_name,
                "target_handle_id": target_handle_id,
                "output_id": output_id,
                "output_type": output_type,
                "node_input_id": node_input_id,
            },
            "inputs": [
                {
                    "id": node_input_id,
                    "key": "node_input",
                    "value": {
                        "rules": [
                            {
                                "type": "NODE_OUTPUT",
                                "data": {"node_id": upstream_node_id, "output_id": upstream_output_id},
                            }
                        ],
                        "combinator": "OR",
                    },
                }
            ],
            "display_data": {"position": {"x": 0.0, "y": 0.0}},
            "base": {
                "name": "FinalOutputNode",
                "module": ["vellum", "workflows", "nodes", "displayable", "final_output_node", "node"],
            },
            "definition": None,
        }

    state_terminal = find_terminal("example_state")
    input_terminal = find_terminal("example_input")

    assert not DeepDiff(
        expected_terminal(
            node_id=state_terminal_id,
            output_name="example_state",
            target_handle_id=state_terminal_handle_id,
            output_id=state_output_id,
            output_type="NUMBER",
            node_input_id="8de6a408-cf76-4d04-9845-f75211b611be",
            upstream_output_id="84b59a1a-82bf-46bb-9826-9b393402d0fe",
        ),
        state_terminal,
    )

    assert not DeepDiff(
        expected_terminal(
            node_id=input_terminal_id,
            output_name="example_input",
            target_handle_id=input_terminal_handle_id,
            output_id=input_output_id,
            output_type="STRING",
            node_input_id="796b4a0b-da10-403a-acc3-8ebd3ebd3667",
            upstream_output_id="f305c61e-6e8f-4cea-a53e-92656136b545",
        ),
        input_terminal,
    )

    # AND each edge should be serialized correctly
    expected_edges = [
        {
            "id": "26003bf1-b5a3-41f4-a419-c7ef5a7595df",
            "source_node_id": entrypoint_id,
            "source_handle_id": entrypoint_handle_id,
            "target_node_id": upstream_node_id,
            "target_handle_id": "a95a34f2-e894-4fb6-a2c9-15d12c1e3135",
            "type": "DEFAULT",
        },
        {
            "id": "b0a57a5f-a1e4-4dc9-85dd-946f08304738",
            "source_node_id": upstream_node_id,
            "source_handle_id": upstream_source_handle_id,
            "target_node_id": input_terminal_id,
            "target_handle_id": input_terminal_handle_id,
            "type": "DEFAULT",
        },
        {
            "id": "e4366583-94a5-40b0-9b6f-1e965695b1fe",
            "source_node_id": upstream_node_id,
            "source_handle_id": upstream_source_handle_id,
            "target_node_id": state_terminal_id,
            "target_handle_id": state_terminal_handle_id,
            "type": "DEFAULT",
        },
    ]
    assert not DeepDiff(expected_edges, edges, ignore_order=True)

    # AND the display data should be what we expect
    assert raw["display_data"] == {"viewport": {"x": 0.0, "y": 0.0, "zoom": 1.0}}

    # AND the definition should be what we expect
    assert raw["definition"] == {
        "name": "BasicDefaultStateWorkflow",
        "module": ["tests", "workflows", "basic_default_state", "workflow"],
    }
@@ -19,6 +19,7 @@ def test_serialize_workflow():
19
19
  assert serialized_workflow.keys() == {
20
20
  "workflow_raw_data",
21
21
  "input_variables",
22
+ "state_variables",
22
23
  "output_variables",
23
24
  }
24
25
 
@@ -19,6 +19,7 @@ def test_serialize_workflow(vellum_client):
19
19
  assert serialized_workflow.keys() == {
20
20
  "workflow_raw_data",
21
21
  "input_variables",
22
+ "state_variables",
22
23
  "output_variables",
23
24
  }
24
25
 
@@ -16,7 +16,7 @@ def test_serialize_workflow():
16
16
  serialized_workflow: dict = workflow_display.serialize()
17
17
 
18
18
  # THEN we should get a serialized representation of the workflow
19
- assert serialized_workflow.keys() == {"workflow_raw_data", "input_variables", "output_variables"}
19
+ assert serialized_workflow.keys() == {"workflow_raw_data", "input_variables", "state_variables", "output_variables"}
20
20
 
21
21
  # AND its input variables should be what we expect
22
22
  input_variables = serialized_workflow["input_variables"]
@@ -18,6 +18,7 @@ def test_serialize_workflow():
18
18
  assert serialized_workflow.keys() == {
19
19
  "workflow_raw_data",
20
20
  "input_variables",
21
+ "state_variables",
21
22
  "output_variables",
22
23
  }
23
24
 
@@ -16,6 +16,7 @@ def test_serialize_workflow():
16
16
  assert serialized_workflow.keys() == {
17
17
  "workflow_raw_data",
18
18
  "input_variables",
19
+ "state_variables",
19
20
  "output_variables",
20
21
  }
21
22
 
@@ -18,6 +18,7 @@ def test_serialize_workflow__await_all():
18
18
  assert serialized_workflow.keys() == {
19
19
  "workflow_raw_data",
20
20
  "input_variables",
21
+ "state_variables",
21
22
  "output_variables",
22
23
  }
23
24
 
@@ -34,6 +34,7 @@ def test_serialize_workflow(vellum_client):
34
34
  assert serialized_workflow.keys() == {
35
35
  "workflow_raw_data",
36
36
  "input_variables",
37
+ "state_variables",
37
38
  "output_variables",
38
39
  }
39
40
 
@@ -15,7 +15,7 @@ def test_serialize_workflow():
15
15
  serialized_workflow: dict = workflow_display.serialize()
16
16
 
17
17
  # THEN we should get a serialized representation of the workflow
18
- assert serialized_workflow.keys() == {"workflow_raw_data", "input_variables", "output_variables"}
18
+ assert serialized_workflow.keys() == {"workflow_raw_data", "input_variables", "state_variables", "output_variables"}
19
19
 
20
20
  # AND its input variables should be what we expect
21
21
  input_variables = serialized_workflow["input_variables"]
@@ -34,6 +34,7 @@ def test_serialize_workflow(vellum_client):
34
34
  assert serialized_workflow.keys() == {
35
35
  "workflow_raw_data",
36
36
  "input_variables",
37
+ "state_variables",
37
38
  "output_variables",
38
39
  }
39
40
 
@@ -20,6 +20,7 @@ def test_serialize_workflow():
20
20
  assert serialized_workflow.keys() == {
21
21
  "workflow_raw_data",
22
22
  "input_variables",
23
+ "state_variables",
23
24
  "output_variables",
24
25
  }
25
26
 
@@ -16,6 +16,7 @@ def test_serialize_workflow():
16
16
  assert serialized_workflow.keys() == {
17
17
  "workflow_raw_data",
18
18
  "input_variables",
19
+ "state_variables",
19
20
  "output_variables",
20
21
  }
21
22
 
@@ -16,6 +16,7 @@ def test_serialize_workflow():
16
16
  assert serialized_workflow.keys() == {
17
17
  "workflow_raw_data",
18
18
  "input_variables",
19
+ "state_variables",
19
20
  "output_variables",
20
21
  }
21
22
 
@@ -22,6 +22,7 @@ def test_serialize_workflow__missing_final_output_node():
22
22
  assert serialized_workflow.keys() == {
23
23
  "workflow_raw_data",
24
24
  "input_variables",
25
+ "state_variables",
25
26
  "output_variables",
26
27
  }
27
28
 
@@ -6,10 +6,11 @@ from vellum.client.core import UniversalBaseModel
6
6
  from vellum.workflows.descriptors.base import BaseDescriptor
7
7
  from vellum.workflows.nodes import BaseNode
8
8
  from vellum.workflows.ports import Port
9
- from vellum.workflows.references import OutputReference, WorkflowInputReference
9
+ from vellum.workflows.references import OutputReference, StateValueReference, WorkflowInputReference
10
10
  from vellum_ee.workflows.display.base import (
11
11
  EdgeDisplayType,
12
12
  EntrypointDisplayType,
13
+ StateValueDisplayType,
13
14
  WorkflowInputsDisplayType,
14
15
  WorkflowMetaDisplayType,
15
16
  WorkflowOutputDisplayType,
@@ -41,6 +42,7 @@ class WorkflowDisplayContext(
41
42
  Generic[
42
43
  WorkflowMetaDisplayType,
43
44
  WorkflowInputsDisplayType,
45
+ StateValueDisplayType,
44
46
  NodeDisplayType,
45
47
  EntrypointDisplayType,
46
48
  WorkflowOutputDisplayType,
@@ -53,6 +55,8 @@ class WorkflowDisplayContext(
53
55
  global_workflow_input_displays: Dict[WorkflowInputReference, WorkflowInputsDisplayType] = field(
54
56
  default_factory=dict
55
57
  )
58
+ state_value_displays: Dict[StateValueReference, StateValueDisplayType] = field(default_factory=dict)
59
+ global_state_value_displays: Dict[StateValueReference, StateValueDisplayType] = field(default_factory=dict)
56
60
  node_displays: Dict[Type[BaseNode], "NodeDisplayType"] = field(default_factory=dict)
57
61
  global_node_displays: Dict[Type[BaseNode], NodeDisplayType] = field(default_factory=dict)
58
62
  global_node_output_displays: Dict[OutputReference, Tuple[Type[BaseNode], "NodeOutputDisplay"]] = field(
@@ -0,0 +1,26 @@
1
+ from vellum.workflows.descriptors.base import BaseDescriptor
2
+ from vellum.workflows.references.lazy import LazyReference
3
+ from vellum_ee.workflows.display.types import WorkflowDisplayContext
4
+
5
+
6
def get_child_descriptor(value: LazyReference, display_context: WorkflowDisplayContext) -> BaseDescriptor:
    """Resolve a ``LazyReference`` to the descriptor it points at.

    A string-based lazy reference is parsed as ``<NodeClass>.Outputs.<output_name>``. The node class is
    looked up among the keys of ``display_context.node_displays``. A non-string lazy reference is resolved
    by calling ``value._get()``.

    Args:
        value: The lazy reference to resolve.
        display_context: The display context whose ``node_displays`` are searched for the node class.

    Returns:
        For string references, ``getattr(node.Outputs, output_name)`` on the matching node class;
        otherwise whatever ``value._get()`` returns.

    Raises:
        Exception: If a string reference has fewer than three dotted parts, if its second-to-last part
            is not ``Outputs``, or if no node in ``display_context.node_displays`` matches the node name.
        AttributeError: If the matching node's ``Outputs`` has no attribute named ``output_name``.
    """
    if not isinstance(value._get, str):
        return value._get()

    raw_reference = value._get
    reference_parts = raw_reference.split(".")
    if len(reference_parts) < 3:
        raise Exception(
            f"Failed to parse lazy reference: {raw_reference}. Only Node Output references are supported."
        )

    *node_path, nested_class_name, output_name = reference_parts
    if nested_class_name != "Outputs":
        raise Exception(
            f"Failed to parse lazy reference: {raw_reference}. Outputs are the only node reference supported."
        )

    # The node portion may itself be dotted (e.g. a node class nested inside another class). ``__name__``
    # never contains a dot, so it can only ever match a single-segment prefix. Also compare against
    # ``__qualname__`` so that dotted prefixes can resolve. Single-segment matches are unaffected.
    node_class_name = ".".join(node_path)
    for node in display_context.node_displays:
        if node_class_name in (node.__name__, node.__qualname__):
            return getattr(node.Outputs, output_name)

    raise Exception(f"Failed to parse lazy reference: {raw_reference}")
@@ -32,10 +32,12 @@ from vellum.workflows.nodes.bases.base import BaseNode
32
32
  from vellum.workflows.nodes.displayable.bases.utils import primitive_to_vellum_value
33
33
  from vellum.workflows.references import OutputReference, WorkflowInputReference
34
34
  from vellum.workflows.references.execution_count import ExecutionCountReference
35
+ from vellum.workflows.references.lazy import LazyReference
35
36
  from vellum.workflows.references.node import NodeReference
36
37
  from vellum.workflows.references.vellum_secret import VellumSecretReference
37
38
  from vellum.workflows.utils.vellum_variables import primitive_type_to_vellum_variable_type
38
39
  from vellum.workflows.vellum_client import create_vellum_client
40
+ from vellum_ee.workflows.display.utils.expressions import get_child_descriptor
39
41
  from vellum_ee.workflows.display.vellum import (
40
42
  ConstantValuePointer,
41
43
  ExecutionCounterData,
@@ -88,6 +90,9 @@ def create_node_input_value_pointer_rule(
88
90
  return NodeOutputPointer(
89
91
  data=NodeOutputData(node_id=str(upstream_node_display.node_id), output_id=str(output_display.id)),
90
92
  )
93
+ if isinstance(value, LazyReference):
94
+ child_descriptor = get_child_descriptor(value, display_context)
95
+ return create_node_input_value_pointer_rule(child_descriptor, display_context)
91
96
  if isinstance(value, WorkflowInputReference):
92
97
  workflow_input_display = display_context.global_workflow_input_displays[value]
93
98
  return InputVariablePointer(data=InputVariableData(input_variable_id=str(workflow_input_display.id)))
@@ -14,6 +14,8 @@ from vellum_ee.workflows.display.base import (
14
14
  EdgeDisplayOverrides,
15
15
  EntrypointDisplay,
16
16
  EntrypointDisplayOverrides,
17
+ StateValueDisplay,
18
+ StateValueDisplayOverrides,
17
19
  WorkflowInputsDisplay,
18
20
  WorkflowInputsDisplayOverrides,
19
21
  WorkflowMetaDisplay,
@@ -84,6 +86,18 @@ class WorkflowInputsVellumDisplay(WorkflowInputsVellumDisplayOverrides):
84
86
  pass
85
87
 
86
88
 
89
@dataclass
class StateValueVellumDisplayOverrides(StateValueDisplay, StateValueDisplayOverrides):
    """Vellum-specific display settings that can be supplied for a single workflow state value.

    Every field declared here defaults to ``None``.
    """

    # Serialized as the state variable's ``key``; when unset, the state value's own name is used instead.
    name: Optional[str] = None
    # Explicit ``required`` flag; when unset, serialization infers it from whether ``None`` is an allowed type.
    required: Optional[bool] = None
    # Serialized under ``extensions.color``.
    color: Optional[str] = None
94
+
95
+
96
@dataclass
class StateValueVellumDisplay(StateValueVellumDisplayOverrides):
    """Display data that ``VellumWorkflowDisplay`` produces for a workflow state value."""
99
+
100
+
87
101
  @dataclass
88
102
  class EdgeVellumDisplayOverrides(EdgeDisplay, EdgeDisplayOverrides):
89
103
  pass
@@ -13,7 +13,7 @@ from vellum.workflows.expressions.coalesce_expression import CoalesceExpression
13
13
  from vellum.workflows.nodes.bases import BaseNode
14
14
  from vellum.workflows.nodes.utils import get_wrapped_node
15
15
  from vellum.workflows.ports import Port
16
- from vellum.workflows.references import OutputReference, WorkflowInputReference
16
+ from vellum.workflows.references import OutputReference, StateValueReference, WorkflowInputReference
17
17
  from vellum.workflows.types.core import JsonObject
18
18
  from vellum.workflows.types.generics import WorkflowType
19
19
  from vellum.workflows.utils.uuids import uuid4_from_hash
@@ -22,6 +22,8 @@ from vellum_ee.workflows.display.base import (
22
22
  EdgeDisplayType,
23
23
  EntrypointDisplayOverridesType,
24
24
  EntrypointDisplayType,
25
+ StateValueDisplayOverridesType,
26
+ StateValueDisplayType,
25
27
  WorkflowInputsDisplayOverridesType,
26
28
  WorkflowInputsDisplayType,
27
29
  WorkflowMetaDisplayOverridesType,
@@ -49,6 +51,8 @@ class BaseWorkflowDisplay(
49
51
  WorkflowMetaDisplayOverridesType,
50
52
  WorkflowInputsDisplayType,
51
53
  WorkflowInputsDisplayOverridesType,
54
+ StateValueDisplayType,
55
+ StateValueDisplayOverridesType,
52
56
  NodeDisplayType,
53
57
  EntrypointDisplayType,
54
58
  EntrypointDisplayOverridesType,
@@ -64,6 +68,9 @@ class BaseWorkflowDisplay(
64
68
  # Used to explicitly specify display data for a workflow's inputs.
65
69
  inputs_display: Dict[WorkflowInputReference, WorkflowInputsDisplayOverridesType] = {}
66
70
 
71
+ # Used to explicitly specify display data for a workflow's state values.
72
+ state_value_displays: Dict[StateValueReference, StateValueDisplayOverridesType] = {}
73
+
67
74
  # Used to explicitly specify display data for a workflow's entrypoints.
68
75
  entrypoint_displays: Dict[Type[BaseNode], EntrypointDisplayOverridesType] = {}
69
76
 
@@ -91,6 +98,7 @@ class BaseWorkflowDisplay(
91
98
  WorkflowDisplayContext[
92
99
  WorkflowMetaDisplayType,
93
100
  WorkflowInputsDisplayType,
101
+ StateValueDisplayType,
94
102
  NodeDisplayType,
95
103
  EntrypointDisplayType,
96
104
  WorkflowOutputDisplayType,
@@ -191,6 +199,7 @@ class BaseWorkflowDisplay(
191
199
  ) -> WorkflowDisplayContext[
192
200
  WorkflowMetaDisplayType,
193
201
  WorkflowInputsDisplayType,
202
+ StateValueDisplayType,
194
203
  NodeDisplayType,
195
204
  EntrypointDisplayType,
196
205
  WorkflowOutputDisplayType,
@@ -244,6 +253,18 @@ class BaseWorkflowDisplay(
244
253
  workflow_input_displays[workflow_input] = input_display
245
254
  global_workflow_input_displays[workflow_input] = input_display
246
255
 
256
+ state_value_displays: Dict[StateValueReference, StateValueDisplayType] = {}
257
+ global_state_value_displays = (
258
+ copy(self._parent_display_context.global_state_value_displays) if self._parent_display_context else {}
259
+ )
260
+ for state_value in self._workflow.get_state_class():
261
+ state_value_display_overrides = self.state_value_displays.get(state_value)
262
+ state_value_display = self._generate_state_value_display(
263
+ state_value, overrides=state_value_display_overrides
264
+ )
265
+ state_value_displays[state_value] = state_value_display
266
+ global_state_value_displays[state_value] = state_value_display
267
+
247
268
  entrypoint_displays: Dict[Type[BaseNode], EntrypointDisplayType] = {}
248
269
  for entrypoint in self._workflow.get_entrypoints():
249
270
  if entrypoint in entrypoint_displays:
@@ -286,6 +307,8 @@ class BaseWorkflowDisplay(
286
307
  workflow_display=workflow_display,
287
308
  workflow_input_displays=workflow_input_displays,
288
309
  global_workflow_input_displays=global_workflow_input_displays,
310
+ state_value_displays=state_value_displays,
311
+ global_state_value_displays=global_state_value_displays,
289
312
  node_displays=node_displays,
290
313
  global_node_output_displays=global_node_output_displays,
291
314
  global_node_displays=global_node_displays,
@@ -306,6 +329,12 @@ class BaseWorkflowDisplay(
306
329
  ) -> WorkflowInputsDisplayType:
307
330
  pass
308
331
 
332
+ @abstractmethod
333
+ def _generate_state_value_display(
334
+ self, state_value: StateValueReference, overrides: Optional[StateValueDisplayOverridesType] = None
335
+ ) -> StateValueDisplayType:
336
+ pass
337
+
309
338
  @abstractmethod
310
339
  def _generate_entrypoint_display(
311
340
  self,
@@ -25,6 +25,8 @@ from vellum_ee.workflows.display.vellum import (
25
25
  EntrypointVellumDisplay,
26
26
  EntrypointVellumDisplayOverrides,
27
27
  NodeDisplayData,
28
+ StateValueVellumDisplay,
29
+ StateValueVellumDisplayOverrides,
28
30
  WorkflowInputsVellumDisplay,
29
31
  WorkflowInputsVellumDisplayOverrides,
30
32
  WorkflowMetaVellumDisplay,
@@ -44,6 +46,8 @@ class VellumWorkflowDisplay(
44
46
  WorkflowMetaVellumDisplayOverrides,
45
47
  WorkflowInputsVellumDisplay,
46
48
  WorkflowInputsVellumDisplayOverrides,
49
+ StateValueVellumDisplay,
50
+ StateValueVellumDisplayOverrides,
47
51
  BaseNodeDisplay,
48
52
  EntrypointVellumDisplay,
49
53
  EntrypointVellumDisplayOverrides,
@@ -76,6 +80,25 @@ class VellumWorkflowDisplay(
76
80
  }
77
81
  )
78
82
 
83
+ state_variables: JsonArray = []
84
+ for state_value, state_value_display in self.display_context.state_value_displays.items():
85
+ default = primitive_to_vellum_value(state_value.instance) if state_value.instance else None
86
+ required = (
87
+ state_value_display.required
88
+ if state_value_display.required is not None
89
+ else type(None) not in state_value.types
90
+ )
91
+ state_variables.append(
92
+ {
93
+ "id": str(state_value_display.id),
94
+ "key": state_value_display.name or state_value.name,
95
+ "type": infer_vellum_variable_type(state_value),
96
+ "default": default.dict() if default else None,
97
+ "required": required,
98
+ "extensions": {"color": state_value_display.color},
99
+ }
100
+ )
101
+
79
102
  nodes: JsonArray = []
80
103
  edges: JsonArray = []
81
104
 
@@ -244,6 +267,7 @@ class VellumWorkflowDisplay(
244
267
  },
245
268
  },
246
269
  "input_variables": input_variables,
270
+ "state_variables": state_variables,
247
271
  "output_variables": output_variables,
248
272
  }
249
273
 
@@ -283,6 +307,23 @@ class VellumWorkflowDisplay(
283
307
 
284
308
  return WorkflowInputsVellumDisplay(id=workflow_input_id, name=name, required=required, color=color)
285
309
 
310
+ def _generate_state_value_display(
311
+ self, state_value: BaseDescriptor, overrides: Optional[StateValueVellumDisplayOverrides] = None
312
+ ) -> StateValueVellumDisplay:
313
+ state_value_id: UUID
314
+ name = None
315
+ required = None
316
+ color = None
317
+ if overrides:
318
+ state_value_id = overrides.id
319
+ name = overrides.name
320
+ required = overrides.required
321
+ color = overrides.color
322
+ else:
323
+ state_value_id = uuid4_from_hash(f"{self.workflow_id}|state_values|id|{state_value.name}")
324
+
325
+ return StateValueVellumDisplay(id=state_value_id, name=name, required=required, color=color)
326
+
286
327
  def _generate_entrypoint_display(
287
328
  self,
288
329
  entrypoint: Type[BaseNode],