vellum-ai 0.12.17__py3-none-any.whl → 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (83)
  1. vellum/__init__.py +10 -22
  2. vellum/client/__init__.py +8 -0
  3. vellum/client/core/client_wrapper.py +1 -1
  4. vellum/client/resources/__init__.py +4 -0
  5. vellum/client/resources/organizations/__init__.py +2 -0
  6. vellum/client/resources/organizations/client.py +116 -0
  7. vellum/client/resources/workspaces/__init__.py +2 -0
  8. vellum/client/resources/workspaces/client.py +114 -0
  9. vellum/client/types/__init__.py +6 -22
  10. vellum/client/types/new_member_join_behavior_enum.py +8 -0
  11. vellum/client/types/{function_call_variable_value.py → organization_read.py} +6 -4
  12. vellum/client/types/workflow_execution_workflow_result_event.py +0 -2
  13. vellum/client/types/workflow_result_event.py +0 -2
  14. vellum/client/types/workflow_result_event_output_data_array.py +4 -4
  15. vellum/client/types/{string_variable_value.py → workspace_read.py} +12 -5
  16. vellum/{types/json_variable_value.py → resources/organizations/__init__.py} +1 -1
  17. vellum/resources/organizations/client.py +3 -0
  18. vellum/{types/image_variable_value.py → resources/workspaces/__init__.py} +1 -1
  19. vellum/{types/array_variable_value.py → resources/workspaces/client.py} +1 -1
  20. vellum/types/{array_variable_value_item.py → new_member_join_behavior_enum.py} +1 -1
  21. vellum/types/{audio_variable_value.py → organization_read.py} +1 -1
  22. vellum/types/{error_variable_value.py → workspace_read.py} +1 -1
  23. vellum/workflows/workflows/base.py +0 -32
  24. {vellum_ai-0.12.17.dist-info → vellum_ai-0.13.0.dist-info}/METADATA +1 -1
  25. {vellum_ai-0.12.17.dist-info → vellum_ai-0.13.0.dist-info}/RECORD +69 -76
  26. vellum_ee/workflows/display/nodes/base_node_display.py +17 -10
  27. vellum_ee/workflows/display/nodes/vellum/api_node.py +1 -0
  28. vellum_ee/workflows/display/nodes/vellum/base_node.py +97 -2
  29. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +1 -0
  30. vellum_ee/workflows/display/nodes/vellum/conditional_node.py +5 -62
  31. vellum_ee/workflows/display/nodes/vellum/error_node.py +1 -0
  32. vellum_ee/workflows/display/nodes/vellum/final_output_node.py +1 -0
  33. vellum_ee/workflows/display/nodes/vellum/guardrail_node.py +1 -0
  34. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +1 -0
  35. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +1 -0
  36. vellum_ee/workflows/display/nodes/vellum/map_node.py +1 -0
  37. vellum_ee/workflows/display/nodes/vellum/merge_node.py +1 -0
  38. vellum_ee/workflows/display/nodes/vellum/note_node.py +1 -0
  39. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +1 -0
  40. vellum_ee/workflows/display/nodes/vellum/search_node.py +1 -0
  41. vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +1 -0
  42. vellum_ee/workflows/display/nodes/vellum/templating_node.py +1 -0
  43. vellum_ee/workflows/display/nodes/vellum/utils.py +63 -0
  44. vellum_ee/workflows/display/tests/test_vellum_workflow_display.py +2 -5
  45. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +18 -2
  46. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_ports_serialization.py +355 -0
  47. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_trigger_serialization.py +37 -22
  48. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +12 -56
  49. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +43 -93
  50. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +31 -151
  51. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +8 -26
  52. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +4 -15
  53. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +9 -44
  54. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +19 -101
  55. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +19 -73
  56. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +9 -44
  57. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +9 -44
  58. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +8 -6
  59. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +11 -58
  60. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +8 -11
  61. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +7 -30
  62. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +2 -11
  63. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +9 -44
  64. vellum_ee/workflows/display/vellum.py +2 -7
  65. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +5 -9
  66. vellum_ee/workflows/server/virtual_file_loader.py +3 -3
  67. vellum/client/types/array_variable_value.py +0 -27
  68. vellum/client/types/array_variable_value_item.py +0 -29
  69. vellum/client/types/audio_variable_value.py +0 -25
  70. vellum/client/types/chat_history_variable_value.py +0 -21
  71. vellum/client/types/error_variable_value.py +0 -21
  72. vellum/client/types/image_variable_value.py +0 -25
  73. vellum/client/types/json_variable_value.py +0 -20
  74. vellum/client/types/number_variable_value.py +0 -20
  75. vellum/client/types/search_results_variable_value.py +0 -21
  76. vellum/types/chat_history_variable_value.py +0 -3
  77. vellum/types/function_call_variable_value.py +0 -3
  78. vellum/types/number_variable_value.py +0 -3
  79. vellum/types/search_results_variable_value.py +0 -3
  80. vellum/types/string_variable_value.py +0 -3
  81. {vellum_ai-0.12.17.dist-info → vellum_ai-0.13.0.dist-info}/LICENSE +0 -0
  82. {vellum_ai-0.12.17.dist-info → vellum_ai-0.13.0.dist-info}/WHEEL +0 -0
  83. {vellum_ai-0.12.17.dist-info → vellum_ai-0.13.0.dist-info}/entry_points.txt +0 -0
@@ -4,34 +4,16 @@ from typing import Any, ClassVar, Dict, Generic, List, Optional, Tuple, TypeVar,
4
4
 
5
5
  from vellum.workflows.descriptors.base import BaseDescriptor
6
6
  from vellum.workflows.expressions.and_ import AndExpression
7
- from vellum.workflows.expressions.begins_with import BeginsWithExpression
8
7
  from vellum.workflows.expressions.between import BetweenExpression
9
- from vellum.workflows.expressions.contains import ContainsExpression
10
- from vellum.workflows.expressions.does_not_begin_with import DoesNotBeginWithExpression
11
- from vellum.workflows.expressions.does_not_contain import DoesNotContainExpression
12
- from vellum.workflows.expressions.does_not_end_with import DoesNotEndWithExpression
13
- from vellum.workflows.expressions.does_not_equal import DoesNotEqualExpression
14
- from vellum.workflows.expressions.ends_with import EndsWithExpression
15
- from vellum.workflows.expressions.equals import EqualsExpression
16
- from vellum.workflows.expressions.greater_than import GreaterThanExpression
17
- from vellum.workflows.expressions.greater_than_or_equal_to import GreaterThanOrEqualToExpression
18
- from vellum.workflows.expressions.in_ import InExpression
19
- from vellum.workflows.expressions.is_nil import IsNilExpression
20
- from vellum.workflows.expressions.is_not_nil import IsNotNilExpression
21
8
  from vellum.workflows.expressions.is_not_null import IsNotNullExpression
22
- from vellum.workflows.expressions.is_not_undefined import IsNotUndefinedExpression
23
9
  from vellum.workflows.expressions.is_null import IsNullExpression
24
- from vellum.workflows.expressions.is_undefined import IsUndefinedExpression
25
- from vellum.workflows.expressions.less_than import LessThanExpression
26
- from vellum.workflows.expressions.less_than_or_equal_to import LessThanOrEqualToExpression
27
10
  from vellum.workflows.expressions.not_between import NotBetweenExpression
28
- from vellum.workflows.expressions.not_in import NotInExpression
29
11
  from vellum.workflows.expressions.or_ import OrExpression
30
12
  from vellum.workflows.nodes.displayable import ConditionalNode
31
13
  from vellum.workflows.types.core import ConditionType, JsonObject
32
14
  from vellum.workflows.utils.uuids import uuid4_from_hash
33
15
  from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeVellumDisplay
34
- from vellum_ee.workflows.display.nodes.vellum.utils import create_node_input
16
+ from vellum_ee.workflows.display.nodes.vellum.utils import convert_descriptor_to_operator, create_node_input
35
17
  from vellum_ee.workflows.display.types import WorkflowDisplayContext
36
18
  from vellum_ee.workflows.display.vellum import NodeInput
37
19
 
@@ -113,7 +95,7 @@ but the defined conditions have length {len(condition_ids)}"""
113
95
  )
114
96
  node_inputs.append(expression_node_input)
115
97
  field_node_input_id = expression_node_input.id
116
- operator = self._convert_descriptor_to_operator(descriptor)
98
+ operator = convert_descriptor_to_operator(descriptor)
117
99
 
118
100
  elif isinstance(descriptor, (BetweenExpression, NotBetweenExpression)):
119
101
  field_node_input = create_node_input(
@@ -127,7 +109,7 @@ but the defined conditions have length {len(condition_ids)}"""
127
109
  value_node_input_id,
128
110
  )
129
111
  node_inputs.extend([field_node_input, value_node_input])
130
- operator = self._convert_descriptor_to_operator(descriptor)
112
+ operator = convert_descriptor_to_operator(descriptor)
131
113
  field_node_input_id = field_node_input.id
132
114
  value_node_input_id = value_node_input.id
133
115
 
@@ -147,7 +129,7 @@ but the defined conditions have length {len(condition_ids)}"""
147
129
  node_inputs.append(rhs_node_input)
148
130
  value_node_input_id = rhs_node_input.id
149
131
 
150
- operator = self._convert_descriptor_to_operator(descriptor)
132
+ operator = convert_descriptor_to_operator(descriptor)
151
133
  field_node_input_id = lhs_node_input.id
152
134
 
153
135
  return {
@@ -218,49 +200,10 @@ but the defined conditions have length {len(condition_ids)}"""
218
200
  "version": "2",
219
201
  },
220
202
  "display_data": self.get_display_data().dict(),
203
+ "base": self.get_base().dict(),
221
204
  "definition": self.get_definition().dict(),
222
205
  }
223
206
 
224
- def _convert_descriptor_to_operator(self, descriptor: BaseDescriptor) -> str:
225
- if isinstance(descriptor, EqualsExpression):
226
- return "="
227
- elif isinstance(descriptor, DoesNotEqualExpression):
228
- return "!="
229
- elif isinstance(descriptor, LessThanExpression):
230
- return "<"
231
- elif isinstance(descriptor, GreaterThanExpression):
232
- return ">"
233
- elif isinstance(descriptor, LessThanOrEqualToExpression):
234
- return "<="
235
- elif isinstance(descriptor, GreaterThanOrEqualToExpression):
236
- return ">="
237
- elif isinstance(descriptor, ContainsExpression):
238
- return "contains"
239
- elif isinstance(descriptor, BeginsWithExpression):
240
- return "beginsWith"
241
- elif isinstance(descriptor, EndsWithExpression):
242
- return "endsWith"
243
- elif isinstance(descriptor, DoesNotContainExpression):
244
- return "doesNotContain"
245
- elif isinstance(descriptor, DoesNotBeginWithExpression):
246
- return "doesNotBeginWith"
247
- elif isinstance(descriptor, DoesNotEndWithExpression):
248
- return "doesNotEndWith"
249
- elif isinstance(descriptor, (IsNullExpression, IsNilExpression, IsUndefinedExpression)):
250
- return "null"
251
- elif isinstance(descriptor, (IsNotNullExpression, IsNotNilExpression, IsNotUndefinedExpression)):
252
- return "notNull"
253
- elif isinstance(descriptor, InExpression):
254
- return "in"
255
- elif isinstance(descriptor, NotInExpression):
256
- return "notIn"
257
- elif isinstance(descriptor, BetweenExpression):
258
- return "between"
259
- elif isinstance(descriptor, NotBetweenExpression):
260
- return "notBetween"
261
- else:
262
- raise ValueError(f"Unsupported descriptor type: {descriptor}")
263
-
264
207
  def get_nested_rule_details_by_path(
265
208
  self, rule_ids: List[RuleIdMap], path: List[int]
266
209
  ) -> Union[Tuple[str, Optional[str], Optional[str]], None]:
@@ -44,5 +44,6 @@ class BaseErrorNodeDisplay(BaseNodeVellumDisplay[_ErrorNodeType], Generic[_Error
44
44
  "error_output_id": str(self.error_output_id),
45
45
  },
46
46
  "display_data": self.get_display_data().dict(),
47
+ "base": self.get_base().dict(),
47
48
  "definition": self.get_definition().dict(),
48
49
  }
@@ -45,6 +45,7 @@ class BaseFinalOutputNodeDisplay(BaseNodeVellumDisplay[_FinalOutputNodeType], Ge
45
45
  },
46
46
  "inputs": [node_input.dict()],
47
47
  "display_data": self.get_display_data().dict(),
48
+ "base": self.get_base().dict(),
48
49
  "definition": self.get_definition().dict(),
49
50
  }
50
51
 
@@ -45,5 +45,6 @@ class BaseGuardrailNodeDisplay(BaseNodeVellumDisplay[_GuardrailNodeType], Generi
45
45
  "release_tag": raise_if_descriptor(node.release_tag),
46
46
  },
47
47
  "display_data": self.get_display_data().dict(),
48
+ "base": self.get_base().dict(),
48
49
  "definition": self.get_definition().dict(),
49
50
  }
@@ -59,6 +59,7 @@ class BaseInlinePromptNodeDisplay(BaseNodeVellumDisplay[_InlinePromptNodeType],
59
59
  "ml_model_name": raise_if_descriptor(node.ml_model),
60
60
  },
61
61
  "display_data": self.get_display_data().dict(),
62
+ "base": self.get_base().dict(),
62
63
  "definition": self.get_definition().dict(),
63
64
  }
64
65
 
@@ -52,6 +52,7 @@ class BaseInlineSubworkflowNodeDisplay(
52
52
  "output_variables": [workflow_output.dict() for workflow_output in workflow_outputs],
53
53
  },
54
54
  "display_data": self.get_display_data().dict(),
55
+ "base": self.get_base().dict(),
55
56
  "definition": self.get_definition().dict(),
56
57
  }
57
58
 
@@ -75,5 +75,6 @@ class BaseMapNodeDisplay(BaseNodeVellumDisplay[_MapNodeType], Generic[_MapNodeTy
75
75
  "index_input_id": index_workflow_input_id,
76
76
  },
77
77
  "display_data": self.get_display_data().dict(),
78
+ "base": self.get_base().dict(),
78
79
  "definition": self.get_definition().dict(),
79
80
  }
@@ -45,6 +45,7 @@ class BaseMergeNodeDisplay(BaseNodeVellumDisplay[_MergeNodeType], Generic[_Merge
45
45
  "source_handle_id": str(self.get_source_handle_id(display_context.port_displays)),
46
46
  },
47
47
  "display_data": self.get_display_data().dict(),
48
+ "base": self.get_base().dict(),
48
49
  "definition": self.get_definition().dict(),
49
50
  }
50
51
 
@@ -26,5 +26,6 @@ class BaseNoteNodeDisplay(BaseNodeVellumDisplay[_NoteNodeType], Generic[_NoteNod
26
26
  "style": json.dumps(self.style) if self.style else None,
27
27
  },
28
28
  "display_data": self.get_display_data().dict(),
29
+ "base": self.get_base().dict(),
29
30
  "definition": self.get_definition().dict(),
30
31
  }
@@ -64,5 +64,6 @@ class BasePromptDeploymentNodeDisplay(
64
64
  "release_tag": raise_if_descriptor(node.release_tag),
65
65
  },
66
66
  "display_data": self.get_display_data().dict(),
67
+ "base": self.get_base().dict(),
67
68
  "definition": self.get_definition().dict(),
68
69
  }
@@ -65,6 +65,7 @@ class BaseSearchNodeDisplay(BaseNodeVellumDisplay[_SearchNodeType], Generic[_Sea
65
65
  "metadata_filters_node_input_id": str(node_inputs["metadata_filters"].id),
66
66
  },
67
67
  "display_data": self.get_display_data().dict(),
68
+ "base": self.get_base().dict(),
68
69
  "definition": self.get_definition().dict(),
69
70
  }
70
71
 
@@ -56,5 +56,6 @@ class BaseSubworkflowDeploymentNodeDisplay(
56
56
  "release_tag": raise_if_descriptor(node.release_tag),
57
57
  },
58
58
  "display_data": self.get_display_data().dict(),
59
+ "base": self.get_base().dict(),
59
60
  "definition": self.get_definition().dict(),
60
61
  }
@@ -73,5 +73,6 @@ class BaseTemplatingNodeDisplay(BaseNodeVellumDisplay[_TemplatingNodeType], Gene
73
73
  "output_type": inferred_output_type,
74
74
  },
75
75
  "display_data": self.get_display_data().dict(),
76
+ "base": self.get_base().dict(),
76
77
  "definition": self.get_definition().dict(),
77
78
  }
@@ -2,7 +2,29 @@ from uuid import UUID
2
2
  from typing import Any, List, Optional, Type, Union, cast
3
3
 
4
4
  from vellum.workflows.descriptors.base import BaseDescriptor
5
+ from vellum.workflows.expressions.begins_with import BeginsWithExpression
6
+ from vellum.workflows.expressions.between import BetweenExpression
5
7
  from vellum.workflows.expressions.coalesce_expression import CoalesceExpression
8
+ from vellum.workflows.expressions.contains import ContainsExpression
9
+ from vellum.workflows.expressions.does_not_begin_with import DoesNotBeginWithExpression
10
+ from vellum.workflows.expressions.does_not_contain import DoesNotContainExpression
11
+ from vellum.workflows.expressions.does_not_end_with import DoesNotEndWithExpression
12
+ from vellum.workflows.expressions.does_not_equal import DoesNotEqualExpression
13
+ from vellum.workflows.expressions.ends_with import EndsWithExpression
14
+ from vellum.workflows.expressions.equals import EqualsExpression
15
+ from vellum.workflows.expressions.greater_than import GreaterThanExpression
16
+ from vellum.workflows.expressions.greater_than_or_equal_to import GreaterThanOrEqualToExpression
17
+ from vellum.workflows.expressions.in_ import InExpression
18
+ from vellum.workflows.expressions.is_nil import IsNilExpression
19
+ from vellum.workflows.expressions.is_not_nil import IsNotNilExpression
20
+ from vellum.workflows.expressions.is_not_null import IsNotNullExpression
21
+ from vellum.workflows.expressions.is_not_undefined import IsNotUndefinedExpression
22
+ from vellum.workflows.expressions.is_null import IsNullExpression
23
+ from vellum.workflows.expressions.is_undefined import IsUndefinedExpression
24
+ from vellum.workflows.expressions.less_than import LessThanExpression
25
+ from vellum.workflows.expressions.less_than_or_equal_to import LessThanOrEqualToExpression
26
+ from vellum.workflows.expressions.not_between import NotBetweenExpression
27
+ from vellum.workflows.expressions.not_in import NotInExpression
6
28
  from vellum.workflows.nodes.utils import get_wrapped_node, has_wrapped_node
7
29
  from vellum.workflows.references import NodeReference, OutputReference
8
30
  from vellum.workflows.utils.uuids import uuid4_from_hash
@@ -109,3 +131,44 @@ def create_pointer(
109
131
  return ConstantValuePointer(type="CONSTANT_VALUE", data=vellum_variable_value)
110
132
  else:
111
133
  raise ValueError(f"Pointer type {pointer_type} not supported")
134
+
135
+
136
+ def convert_descriptor_to_operator(descriptor: BaseDescriptor) -> str:
137
+ if isinstance(descriptor, EqualsExpression):
138
+ return "="
139
+ elif isinstance(descriptor, DoesNotEqualExpression):
140
+ return "!="
141
+ elif isinstance(descriptor, LessThanExpression):
142
+ return "<"
143
+ elif isinstance(descriptor, GreaterThanExpression):
144
+ return ">"
145
+ elif isinstance(descriptor, LessThanOrEqualToExpression):
146
+ return "<="
147
+ elif isinstance(descriptor, GreaterThanOrEqualToExpression):
148
+ return ">="
149
+ elif isinstance(descriptor, ContainsExpression):
150
+ return "contains"
151
+ elif isinstance(descriptor, BeginsWithExpression):
152
+ return "beginsWith"
153
+ elif isinstance(descriptor, EndsWithExpression):
154
+ return "endsWith"
155
+ elif isinstance(descriptor, DoesNotContainExpression):
156
+ return "doesNotContain"
157
+ elif isinstance(descriptor, DoesNotBeginWithExpression):
158
+ return "doesNotBeginWith"
159
+ elif isinstance(descriptor, DoesNotEndWithExpression):
160
+ return "doesNotEndWith"
161
+ elif isinstance(descriptor, (IsNullExpression, IsNilExpression, IsUndefinedExpression)):
162
+ return "null"
163
+ elif isinstance(descriptor, (IsNotNullExpression, IsNotNilExpression, IsNotUndefinedExpression)):
164
+ return "notNull"
165
+ elif isinstance(descriptor, InExpression):
166
+ return "in"
167
+ elif isinstance(descriptor, NotInExpression):
168
+ return "notIn"
169
+ elif isinstance(descriptor, BetweenExpression):
170
+ return "between"
171
+ elif isinstance(descriptor, NotBetweenExpression):
172
+ return "notBetween"
173
+ else:
174
+ raise ValueError(f"Unsupported descriptor type: {descriptor}")
@@ -30,11 +30,8 @@ def test_vellum_workflow_display__serialize_empty_workflow():
30
30
  "nodes": [
31
31
  {
32
32
  "data": {"label": "Entrypoint Node", "source_handle_id": "508b8b82-3517-4672-a155-18c9c7b9c545"},
33
- "definition": {
34
- "bases": [],
35
- "module": ["vellum", "workflows", "nodes", "bases", "base"],
36
- "name": "BaseNode",
37
- },
33
+ "base": None,
34
+ "definition": None,
38
35
  "display_data": {"position": {"x": 0.0, "y": 0.0}},
39
36
  "id": "9eef0c18-f322-4d56-aa89-f088d3e53f6a",
40
37
  "inputs": [],
@@ -1,16 +1,29 @@
1
1
  import pytest
2
2
  from uuid import uuid4
3
+ from typing import Dict, Tuple, Type
3
4
 
5
+ from vellum.workflows.nodes.bases.base import BaseNode
6
+ from vellum.workflows.references.output import OutputReference
7
+ from vellum.workflows.references.workflow_input import WorkflowInputReference
8
+ from vellum.workflows.types.core import JsonObject
9
+ from vellum.workflows.types.generics import NodeType
10
+ from vellum_ee.workflows.display.base import WorkflowInputsDisplayType
4
11
  from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
12
+ from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay
5
13
  from vellum_ee.workflows.display.nodes.vellum.base_node import BaseNodeDisplay
6
- from vellum_ee.workflows.display.types import WorkflowDisplayContext
14
+ from vellum_ee.workflows.display.types import NodeDisplayType, WorkflowDisplayContext
7
15
  from vellum_ee.workflows.display.vellum import NodeDisplayData, WorkflowMetaVellumDisplay
8
16
  from vellum_ee.workflows.display.workflows.vellum_workflow_display import VellumWorkflowDisplay
9
17
 
10
18
 
11
19
  @pytest.fixture()
12
20
  def serialize_node():
13
- def _serialize_node(node_class) -> dict:
21
+ def _serialize_node(
22
+ node_class: Type[NodeType],
23
+ global_workflow_input_displays: Dict[WorkflowInputReference, WorkflowInputsDisplayType] = {},
24
+ global_node_displays: Dict[Type[BaseNode], NodeDisplayType] = {},
25
+ global_node_output_displays: Dict[OutputReference, Tuple[Type[BaseNode], NodeOutputDisplay]] = {},
26
+ ) -> JsonObject:
14
27
  node_display_class = get_node_display_class(BaseNodeDisplay, node_class)
15
28
  node_display = node_display_class()
16
29
 
@@ -22,6 +35,9 @@ def serialize_node():
22
35
  entrypoint_node_display=NodeDisplayData(),
23
36
  ),
24
37
  node_displays={node_class: node_display},
38
+ global_workflow_input_displays=global_workflow_input_displays,
39
+ global_node_displays=global_node_displays,
40
+ global_node_output_displays=global_node_output_displays,
25
41
  )
26
42
  return node_display.serialize(context)
27
43