vellum-ai 0.14.38__py3-none-any.whl → 0.14.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. vellum/__init__.py +2 -0
  2. vellum/client/core/client_wrapper.py +1 -1
  3. vellum/client/types/__init__.py +2 -0
  4. vellum/client/types/test_suite_run_progress.py +20 -0
  5. vellum/client/types/test_suite_run_read.py +3 -0
  6. vellum/client/types/vellum_sdk_error_code_enum.py +1 -0
  7. vellum/client/types/workflow_execution_event_error_code.py +1 -0
  8. vellum/types/test_suite_run_progress.py +3 -0
  9. vellum/workflows/errors/types.py +1 -0
  10. vellum/workflows/events/tests/test_event.py +1 -0
  11. vellum/workflows/events/workflow.py +13 -3
  12. vellum/workflows/nodes/bases/tests/test_base_node.py +24 -0
  13. vellum/workflows/nodes/core/try_node/node.py +1 -2
  14. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +7 -1
  15. vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +6 -1
  16. vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +26 -0
  17. vellum/workflows/nodes/experimental/tool_calling_node/__init__.py +3 -0
  18. vellum/workflows/nodes/experimental/tool_calling_node/node.py +147 -0
  19. vellum/workflows/nodes/experimental/tool_calling_node/utils.py +132 -0
  20. vellum/workflows/nodes/utils.py +4 -2
  21. vellum/workflows/outputs/base.py +3 -2
  22. vellum/workflows/references/output.py +20 -0
  23. vellum/workflows/runner/runner.py +37 -17
  24. vellum/workflows/state/base.py +64 -19
  25. vellum/workflows/state/tests/test_state.py +31 -22
  26. vellum/workflows/types/stack.py +11 -0
  27. vellum/workflows/workflows/base.py +13 -18
  28. vellum/workflows/workflows/tests/test_base_workflow.py +96 -9
  29. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/METADATA +1 -1
  30. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/RECORD +82 -75
  31. vellum_cli/push.py +2 -5
  32. vellum_cli/tests/test_push.py +52 -0
  33. vellum_ee/workflows/display/base.py +14 -1
  34. vellum_ee/workflows/display/nodes/base_node_display.py +56 -14
  35. vellum_ee/workflows/display/nodes/get_node_display_class.py +9 -15
  36. vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +36 -0
  37. vellum_ee/workflows/display/nodes/vellum/base_adornment_node.py +3 -2
  38. vellum_ee/workflows/display/nodes/vellum/retry_node.py +1 -2
  39. vellum_ee/workflows/display/nodes/vellum/tests/test_code_execution_node.py +1 -2
  40. vellum_ee/workflows/display/nodes/vellum/tests/test_error_node.py +1 -2
  41. vellum_ee/workflows/display/nodes/vellum/tests/test_note_node.py +1 -2
  42. vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_node.py +55 -3
  43. vellum_ee/workflows/display/nodes/vellum/tests/test_retry_node.py +1 -2
  44. vellum_ee/workflows/display/nodes/vellum/tests/test_templating_node.py +1 -2
  45. vellum_ee/workflows/display/nodes/vellum/tests/test_try_node.py +1 -2
  46. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +2 -2
  47. vellum_ee/workflows/display/nodes/vellum/try_node.py +1 -2
  48. vellum_ee/workflows/display/nodes/vellum/utils.py +7 -1
  49. vellum_ee/workflows/display/tests/{test_vellum_workflow_display.py → test_base_workflow_display.py} +10 -22
  50. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +4 -6
  51. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +7 -16
  52. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +2 -6
  53. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +1 -2
  54. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +3 -10
  55. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +4 -5
  56. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +1 -4
  57. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +1 -4
  58. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +2 -5
  59. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +7 -5
  60. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_prompt_node_serialization.py +1 -4
  61. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -4
  62. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +1 -2
  63. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +1 -4
  64. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -4
  65. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +7 -5
  66. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +1 -4
  67. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -4
  68. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +1 -4
  69. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +2 -5
  70. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +2 -7
  71. vellum_ee/workflows/display/types.py +5 -4
  72. vellum_ee/workflows/display/utils/exceptions.py +7 -0
  73. vellum_ee/workflows/display/utils/registry.py +37 -0
  74. vellum_ee/workflows/display/utils/vellum.py +2 -1
  75. vellum_ee/workflows/display/workflows/base_workflow_display.py +281 -43
  76. vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +34 -21
  77. vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +58 -20
  78. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +4 -257
  79. vellum_ee/workflows/tests/local_workflow/display/workflow.py +2 -2
  80. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/LICENSE +0 -0
  81. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/WHEEL +0 -0
  82. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/entry_points.txt +0 -0
@@ -1,11 +1,12 @@
1
1
  from dataclasses import dataclass, field
2
- from typing import TYPE_CHECKING, Dict, Tuple, Type, TypeVar
2
+ from typing import TYPE_CHECKING, Dict, Tuple, Type
3
3
 
4
4
  from vellum.workflows.descriptors.base import BaseDescriptor
5
5
  from vellum.workflows.events.workflow import WorkflowEventDisplayContext # noqa: F401
6
6
  from vellum.workflows.nodes import BaseNode
7
7
  from vellum.workflows.ports import Port
8
8
  from vellum.workflows.references import OutputReference, StateValueReference, WorkflowInputReference
9
+ from vellum.workflows.workflows.base import BaseWorkflow
9
10
  from vellum_ee.workflows.display.base import (
10
11
  EdgeDisplay,
11
12
  EntrypointDisplay,
@@ -16,11 +17,11 @@ from vellum_ee.workflows.display.base import (
16
17
  )
17
18
  from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
18
19
  from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay, PortDisplay
20
+ from vellum_ee.workflows.display.utils.registry import get_default_workflow_display_class
19
21
 
20
22
  if TYPE_CHECKING:
21
23
  from vellum_ee.workflows.display.workflows import BaseWorkflowDisplay
22
24
 
23
- WorkflowDisplayType = TypeVar("WorkflowDisplayType", bound="BaseWorkflowDisplay")
24
25
 
25
26
  WorkflowInputsDisplays = Dict[WorkflowInputReference, WorkflowInputsDisplay]
26
27
  StateValueDisplays = Dict[StateValueReference, StateValueDisplay]
@@ -34,8 +35,8 @@ PortDisplays = Dict[Port, PortDisplay]
34
35
 
35
36
  @dataclass
36
37
  class WorkflowDisplayContext:
37
- workflow_display_class: Type["BaseWorkflowDisplay"]
38
- workflow_display: WorkflowMetaDisplay
38
+ workflow_display_class: Type["BaseWorkflowDisplay"] = field(default_factory=get_default_workflow_display_class)
39
+ workflow_display: WorkflowMetaDisplay = field(default_factory=lambda: WorkflowMetaDisplay.get_default(BaseWorkflow))
39
40
  workflow_input_displays: WorkflowInputsDisplays = field(default_factory=dict)
40
41
  global_workflow_input_displays: WorkflowInputsDisplays = field(default_factory=dict)
41
42
  state_value_displays: StateValueDisplays = field(default_factory=dict)
@@ -0,0 +1,7 @@
1
+ class UserFacingException(Exception):
2
+ def to_message(self) -> str:
3
+ return str(self)
4
+
5
+
6
+ class UnsupportedSerializationException(UserFacingException):
7
+ pass
@@ -0,0 +1,37 @@
1
+ from typing import TYPE_CHECKING, Dict, Optional, Type
2
+
3
+ from vellum.workflows.nodes import BaseNode
4
+ from vellum.workflows.workflows.base import BaseWorkflow
5
+
6
+ if TYPE_CHECKING:
7
+ from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
8
+ from vellum_ee.workflows.display.workflows.base_workflow_display import BaseWorkflowDisplay
9
+
10
+
11
+ # Used to store the mapping between workflows and their display classes
12
+ _workflow_display_registry: Dict[Type[BaseWorkflow], Type["BaseWorkflowDisplay"]] = {}
13
+
14
+ # Used to store the mapping between node types and their display classes
15
+ _node_display_registry: Dict[Type[BaseNode], Type["BaseNodeDisplay"]] = {}
16
+
17
+
18
+ def get_from_workflow_display_registry(workflow_class: Type[BaseWorkflow]) -> Optional[Type["BaseWorkflowDisplay"]]:
19
+ return _workflow_display_registry.get(workflow_class)
20
+
21
+
22
+ def register_workflow_display_class(
23
+ workflow_class: Type[BaseWorkflow], workflow_display_class: Type["BaseWorkflowDisplay"]
24
+ ) -> None:
25
+ _workflow_display_registry[workflow_class] = workflow_display_class
26
+
27
+
28
+ def get_default_workflow_display_class() -> Type["BaseWorkflowDisplay"]:
29
+ return _workflow_display_registry[BaseWorkflow]
30
+
31
+
32
+ def get_from_node_display_registry(node_class: Type[BaseNode]) -> Optional[Type["BaseNodeDisplay"]]:
33
+ return _node_display_registry.get(node_class)
34
+
35
+
36
+ def register_node_display_class(node_class: Type[BaseNode], node_display_class: Type["BaseNodeDisplay"]) -> None:
37
+ _node_display_registry[node_class] = node_display_class
@@ -41,6 +41,7 @@ from vellum.workflows.references.node import NodeReference
41
41
  from vellum.workflows.references.vellum_secret import VellumSecretReference
42
42
  from vellum.workflows.utils.vellum_variables import primitive_type_to_vellum_variable_type
43
43
  from vellum.workflows.vellum_client import create_vellum_client
44
+ from vellum_ee.workflows.display.utils.exceptions import UnsupportedSerializationException
44
45
  from vellum_ee.workflows.display.utils.expressions import get_child_descriptor
45
46
 
46
47
  if TYPE_CHECKING:
@@ -164,7 +165,7 @@ def create_node_input_value_pointer_rule(
164
165
  vellum_value = primitive_to_vellum_value(value)
165
166
  return ConstantValuePointer(type="CONSTANT_VALUE", data=vellum_value)
166
167
 
167
- raise ValueError(f"Unsupported descriptor type: {value.__class__.__name__}")
168
+ raise UnsupportedSerializationException(f"Unsupported descriptor type: {value.__class__.__name__}")
168
169
 
169
170
 
170
171
  def convert_descriptor_to_operator(descriptor: BaseDescriptor) -> LogicalOperator:
@@ -1,21 +1,25 @@
1
- from abc import abstractmethod
2
1
  from copy import copy
3
2
  from functools import cached_property
4
3
  import importlib
4
+ import inspect
5
5
  import logging
6
6
  from uuid import UUID
7
- from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, Type, Union, get_args
7
+ from typing import Any, Dict, ForwardRef, Generic, Iterator, List, Optional, Tuple, Type, TypeVar, Union, cast, get_args
8
8
 
9
9
  from vellum.workflows import BaseWorkflow
10
+ from vellum.workflows.constants import undefined
10
11
  from vellum.workflows.descriptors.base import BaseDescriptor
11
12
  from vellum.workflows.edges import Edge
12
13
  from vellum.workflows.events.workflow import NodeEventDisplayContext, WorkflowEventDisplayContext
13
14
  from vellum.workflows.nodes.bases import BaseNode
14
- from vellum.workflows.nodes.utils import get_unadorned_node, get_wrapped_node
15
+ from vellum.workflows.nodes.displayable.bases.utils import primitive_to_vellum_value
16
+ from vellum.workflows.nodes.displayable.final_output_node.node import FinalOutputNode
17
+ from vellum.workflows.nodes.utils import get_unadorned_node, get_unadorned_port, get_wrapped_node
15
18
  from vellum.workflows.ports import Port
16
19
  from vellum.workflows.references import OutputReference, WorkflowInputReference
17
- from vellum.workflows.types.core import JsonObject
20
+ from vellum.workflows.types.core import JsonArray, JsonObject
18
21
  from vellum.workflows.types.generics import WorkflowType
22
+ from vellum.workflows.types.utils import get_original_base
19
23
  from vellum.workflows.utils.uuids import uuid4_from_hash
20
24
  from vellum_ee.workflows.display.base import (
21
25
  EdgeDisplay,
@@ -31,6 +35,7 @@ from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeV
31
35
  from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
32
36
  from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay, PortDisplay
33
37
  from vellum_ee.workflows.display.nodes.utils import raise_if_descriptor
38
+ from vellum_ee.workflows.display.nodes.vellum.utils import create_node_input
34
39
  from vellum_ee.workflows.display.types import (
35
40
  EdgeDisplays,
36
41
  EntrypointDisplays,
@@ -42,6 +47,8 @@ from vellum_ee.workflows.display.types import (
42
47
  WorkflowInputsDisplays,
43
48
  WorkflowOutputDisplays,
44
49
  )
50
+ from vellum_ee.workflows.display.utils.registry import register_workflow_display_class
51
+ from vellum_ee.workflows.display.utils.vellum import infer_vellum_variable_type
45
52
  from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
46
53
 
47
54
  logger = logging.getLogger(__name__)
@@ -69,45 +76,262 @@ class BaseWorkflowDisplay(Generic[WorkflowType]):
69
76
  # Used to explicitly specify display data for a workflow's ports.
70
77
  port_displays: PortDisplays = {}
71
78
 
72
- # Used to store the mapping between workflows and their display classes
73
- _workflow_display_registry: Dict[Type[WorkflowType], Type["BaseWorkflowDisplay"]] = {}
74
-
75
79
  _errors: List[Exception]
76
80
 
77
81
  _dry_run: bool
78
82
 
79
83
  def __init__(
80
84
  self,
81
- workflow: Type[WorkflowType],
82
85
  *,
83
86
  parent_display_context: Optional[WorkflowDisplayContext] = None,
84
87
  dry_run: bool = False,
85
88
  ):
86
- self._workflow = workflow
87
89
  self._parent_display_context = parent_display_context
88
90
  self._errors: List[Exception] = []
89
91
  self._dry_run = dry_run
90
92
 
91
- @abstractmethod
92
93
  def serialize(self) -> JsonObject:
93
- pass
94
+ input_variables: JsonArray = []
95
+ for workflow_input_reference, workflow_input_display in self.display_context.workflow_input_displays.items():
96
+ default = (
97
+ primitive_to_vellum_value(workflow_input_reference.instance)
98
+ if workflow_input_reference.instance
99
+ else None
100
+ )
101
+ input_variables.append(
102
+ {
103
+ "id": str(workflow_input_display.id),
104
+ "key": workflow_input_display.name or workflow_input_reference.name,
105
+ "type": infer_vellum_variable_type(workflow_input_reference),
106
+ "default": default.dict() if default else None,
107
+ "required": workflow_input_reference.instance is undefined,
108
+ "extensions": {"color": workflow_input_display.color},
109
+ }
110
+ )
94
111
 
95
- @classmethod
96
- def get_from_workflow_display_registry(cls, workflow_class: Type[WorkflowType]) -> Type["BaseWorkflowDisplay"]:
97
- try:
98
- return cls._workflow_display_registry[workflow_class]
99
- except KeyError:
100
- return cls._workflow_display_registry[WorkflowType] # type: ignore [misc]
112
+ state_variables: JsonArray = []
113
+ for state_value_reference, state_value_display in self.display_context.state_value_displays.items():
114
+ default = (
115
+ primitive_to_vellum_value(state_value_reference.instance) if state_value_reference.instance else None
116
+ )
117
+ state_variables.append(
118
+ {
119
+ "id": str(state_value_display.id),
120
+ "key": state_value_display.name or state_value_reference.name,
121
+ "type": infer_vellum_variable_type(state_value_reference),
122
+ "default": default.dict() if default else None,
123
+ "required": state_value_reference.instance is undefined,
124
+ "extensions": {"color": state_value_display.color},
125
+ }
126
+ )
127
+
128
+ nodes: JsonArray = []
129
+ edges: JsonArray = []
130
+
131
+ # Add a single synthetic node for the workflow entrypoint
132
+ entrypoint_node_id = self.display_context.workflow_display.entrypoint_node_id
133
+ entrypoint_node_source_handle_id = self.display_context.workflow_display.entrypoint_node_source_handle_id
134
+ nodes.append(
135
+ {
136
+ "id": str(entrypoint_node_id),
137
+ "type": "ENTRYPOINT",
138
+ "inputs": [],
139
+ "data": {
140
+ "label": "Entrypoint Node",
141
+ "source_handle_id": str(entrypoint_node_source_handle_id),
142
+ },
143
+ "display_data": self.display_context.workflow_display.entrypoint_node_display.dict(),
144
+ "base": None,
145
+ "definition": None,
146
+ },
147
+ )
148
+
149
+ # Add all the nodes in the workflow
150
+ for node in self._workflow.get_nodes():
151
+ node_display = self.display_context.node_displays[node]
152
+
153
+ try:
154
+ serialized_node = node_display.serialize(self.display_context)
155
+ except NotImplementedError as e:
156
+ self.add_error(e)
157
+ continue
158
+
159
+ nodes.append(serialized_node)
160
+
161
+ # Add all unused nodes in the workflow
162
+ for node in self._workflow.get_unused_nodes():
163
+ node_display = self.display_context.node_displays[node]
164
+
165
+ try:
166
+ serialized_node = node_display.serialize(self.display_context)
167
+ except NotImplementedError as e:
168
+ self.add_error(e)
169
+ continue
170
+
171
+ nodes.append(serialized_node)
172
+
173
+ synthetic_output_edges: JsonArray = []
174
+ output_variables: JsonArray = []
175
+ final_output_nodes = [
176
+ node for node in self.display_context.node_displays.keys() if issubclass(node, FinalOutputNode)
177
+ ]
178
+ final_output_node_outputs = {node.Outputs.value for node in final_output_nodes}
179
+ unreferenced_final_output_node_outputs = final_output_node_outputs.copy()
180
+ final_output_node_base: JsonObject = {
181
+ "name": FinalOutputNode.__name__,
182
+ "module": cast(JsonArray, FinalOutputNode.__module__.split(".")),
183
+ }
184
+
185
+ # Add a synthetic Terminal Node and track the Workflow's output variables for each Workflow output
186
+ for workflow_output, workflow_output_display in self.display_context.workflow_output_displays.items():
187
+ final_output_node_id = uuid4_from_hash(f"{self.workflow_id}|node_id|{workflow_output.name}")
188
+ inferred_type = infer_vellum_variable_type(workflow_output)
189
+
190
+ # Remove the terminal node output from the unreferenced set
191
+ unreferenced_final_output_node_outputs.discard(cast(OutputReference, workflow_output.instance))
192
+
193
+ if workflow_output.instance not in final_output_node_outputs:
194
+ # Create a synthetic terminal node only if there is no terminal node for this output
195
+ try:
196
+ node_input = create_node_input(
197
+ final_output_node_id,
198
+ "node_input",
199
+ # This is currently the wrapper node's output, but we want the wrapped node
200
+ workflow_output.instance,
201
+ self.display_context,
202
+ )
203
+ except ValueError as e:
204
+ raise ValueError(f"Failed to serialize output '{workflow_output.name}': {str(e)}") from e
205
+
206
+ source_node_display: Optional[BaseNodeDisplay]
207
+ first_rule = node_input.value.rules[0]
208
+ if first_rule.type == "NODE_OUTPUT":
209
+ source_node_id = UUID(first_rule.data.node_id)
210
+ try:
211
+ source_node_display = [
212
+ node_display
213
+ for node_display in self.display_context.node_displays.values()
214
+ if node_display.node_id == source_node_id
215
+ ][0]
216
+ except IndexError:
217
+ source_node_display = None
218
+
219
+ synthetic_target_handle_id = str(
220
+ uuid4_from_hash(f"{self.workflow_id}|target_handle_id|{workflow_output_display.name}")
221
+ )
222
+ synthetic_display_data = NodeDisplayData().dict()
223
+ synthetic_node_label = "Final Output"
224
+ nodes.append(
225
+ {
226
+ "id": str(final_output_node_id),
227
+ "type": "TERMINAL",
228
+ "data": {
229
+ "label": synthetic_node_label,
230
+ "name": workflow_output_display.name,
231
+ "target_handle_id": synthetic_target_handle_id,
232
+ "output_id": str(workflow_output_display.id),
233
+ "output_type": inferred_type,
234
+ "node_input_id": str(node_input.id),
235
+ },
236
+ "inputs": [node_input.dict()],
237
+ "display_data": synthetic_display_data,
238
+ "base": final_output_node_base,
239
+ "definition": None,
240
+ }
241
+ )
242
+
243
+ if source_node_display:
244
+ if isinstance(source_node_display, BaseNodeVellumDisplay):
245
+ source_handle_id = source_node_display.get_source_handle_id(
246
+ port_displays=self.display_context.port_displays
247
+ )
248
+ else:
249
+ source_handle_id = source_node_display.get_node_port_display(
250
+ source_node_display._node.Ports.default
251
+ ).id
252
+
253
+ synthetic_output_edges.append(
254
+ {
255
+ "id": str(uuid4_from_hash(f"{self.workflow_id}|edge_id|{workflow_output_display.name}")),
256
+ "source_node_id": str(source_node_display.node_id),
257
+ "source_handle_id": str(source_handle_id),
258
+ "target_node_id": str(final_output_node_id),
259
+ "target_handle_id": synthetic_target_handle_id,
260
+ "type": "DEFAULT",
261
+ }
262
+ )
263
+
264
+ output_variables.append(
265
+ {
266
+ "id": str(workflow_output_display.id),
267
+ "key": workflow_output_display.name,
268
+ "type": inferred_type,
269
+ }
270
+ )
271
+
272
+ # If there are terminal nodes with no workflow output reference,
273
+ # raise a serialization error
274
+ if len(unreferenced_final_output_node_outputs) > 0:
275
+ self.add_error(
276
+ ValueError("Unable to serialize terminal nodes that are not referenced by workflow outputs.")
277
+ )
278
+
279
+ # Add an edge for each edge in the workflow
280
+ for target_node, entrypoint_display in self.display_context.entrypoint_displays.items():
281
+ unadorned_target_node = get_unadorned_node(target_node)
282
+ target_node_display = self.display_context.node_displays[unadorned_target_node]
283
+ edges.append(
284
+ {
285
+ "id": str(entrypoint_display.edge_display.id),
286
+ "source_node_id": str(entrypoint_node_id),
287
+ "source_handle_id": str(entrypoint_node_source_handle_id),
288
+ "target_node_id": str(target_node_display.node_id),
289
+ "target_handle_id": str(target_node_display.get_trigger_id()),
290
+ "type": "DEFAULT",
291
+ }
292
+ )
293
+
294
+ for (source_node_port, target_node), edge_display in self.display_context.edge_displays.items():
295
+ unadorned_source_node_port = get_unadorned_port(source_node_port)
296
+ unadorned_target_node = get_unadorned_node(target_node)
297
+
298
+ source_node_port_display = self.display_context.port_displays[unadorned_source_node_port]
299
+ target_node_display = self.display_context.node_displays[unadorned_target_node]
300
+
301
+ edges.append(
302
+ {
303
+ "id": str(edge_display.id),
304
+ "source_node_id": str(source_node_port_display.node_id),
305
+ "source_handle_id": str(source_node_port_display.id),
306
+ "target_node_id": str(target_node_display.node_id),
307
+ "target_handle_id": str(
308
+ target_node_display.get_target_handle_id_by_source_node_id(source_node_port_display.node_id)
309
+ ),
310
+ "type": "DEFAULT",
311
+ }
312
+ )
313
+
314
+ edges.extend(synthetic_output_edges)
315
+
316
+ return {
317
+ "workflow_raw_data": {
318
+ "nodes": nodes,
319
+ "edges": edges,
320
+ "display_data": self.display_context.workflow_display.display_data.dict(),
321
+ "definition": {
322
+ "name": self._workflow.__name__,
323
+ "module": cast(JsonArray, self._workflow.__module__.split(".")),
324
+ },
325
+ },
326
+ "input_variables": input_variables,
327
+ "state_variables": state_variables,
328
+ "output_variables": output_variables,
329
+ }
101
330
 
102
331
  @cached_property
103
332
  def workflow_id(self) -> UUID:
104
333
  """Can be overridden as a class attribute to specify a custom workflow id."""
105
- return uuid4_from_hash(self._workflow.__qualname__)
106
-
107
- @property
108
- @abstractmethod
109
- def node_display_base_class(self) -> Type[BaseNodeDisplay]:
110
- pass
334
+ return self._workflow.__id__
111
335
 
112
336
  def add_error(self, error: Exception) -> None:
113
337
  if self._dry_run:
@@ -159,13 +383,8 @@ class BaseWorkflowDisplay(Generic[WorkflowType]):
159
383
  port_displays[port] = node_display.get_node_port_display(port)
160
384
 
161
385
  def _get_node_display(self, node: Type[BaseNode]) -> BaseNodeDisplay:
162
- node_display_class = get_node_display_class(self.node_display_base_class, node)
163
- node_display = node_display_class()
164
-
165
- if not isinstance(node_display, self.node_display_base_class):
166
- raise ValueError(f"{node.__name__} must be a subclass of {self.node_display_base_class.__name__}")
167
-
168
- return node_display
386
+ node_display_class = get_node_display_class(node)
387
+ return node_display_class()
169
388
 
170
389
  @cached_property
171
390
  def display_context(self) -> WorkflowDisplayContext:
@@ -294,14 +513,7 @@ class BaseWorkflowDisplay(Generic[WorkflowType]):
294
513
  display_data=overrides.display_data,
295
514
  )
296
515
 
297
- entrypoint_node_id = uuid4_from_hash(f"{self.workflow_id}|entrypoint_node_id")
298
- entrypoint_node_source_handle_id = uuid4_from_hash(f"{self.workflow_id}|entrypoint_node_source_handle_id")
299
-
300
- return WorkflowMetaDisplay(
301
- entrypoint_node_id=entrypoint_node_id,
302
- entrypoint_node_source_handle_id=entrypoint_node_source_handle_id,
303
- entrypoint_node_display=NodeDisplayData(),
304
- )
516
+ return WorkflowMetaDisplay.get_default(self._workflow)
305
517
 
306
518
  def _generate_workflow_input_display(
307
519
  self, workflow_input: WorkflowInputReference, overrides: Optional[WorkflowInputsDisplay] = None
@@ -368,11 +580,13 @@ class BaseWorkflowDisplay(Generic[WorkflowType]):
368
580
  super().__init_subclass__(**kwargs)
369
581
 
370
582
  workflow_class = get_args(cls.__orig_bases__[0])[0] # type: ignore [attr-defined]
371
- cls._workflow_display_registry[workflow_class] = cls
583
+ register_workflow_display_class(workflow_class=workflow_class, workflow_display_class=cls)
372
584
 
373
585
  @staticmethod
374
586
  def gather_event_display_context(
375
- module_path: str, workflow_class: Type[BaseWorkflow]
587
+ module_path: str,
588
+ # DEPRECATED: This will be removed in the 0.15.0 release
589
+ workflow_class: Optional[Type[BaseWorkflow]] = None,
376
590
  ) -> Union[WorkflowEventDisplayContext, None]:
377
591
  workflow_display_module = f"{module_path}.display.workflow"
378
592
  try:
@@ -380,11 +594,11 @@ class BaseWorkflowDisplay(Generic[WorkflowType]):
380
594
  except ModuleNotFoundError:
381
595
  return None
382
596
 
383
- workflow_display = display_class.WorkflowDisplay(workflow_class)
384
- if not isinstance(workflow_display, BaseWorkflowDisplay):
597
+ WorkflowDisplayClass = display_class.WorkflowDisplay
598
+ if not isinstance(WorkflowDisplayClass, type) or not issubclass(WorkflowDisplayClass, BaseWorkflowDisplay):
385
599
  return None
386
600
 
387
- return workflow_display.get_event_display_context()
601
+ return WorkflowDisplayClass().get_event_display_context()
388
602
 
389
603
  def get_event_display_context(self):
390
604
  display_context = self.display_context
@@ -494,3 +708,27 @@ class BaseWorkflowDisplay(Generic[WorkflowType]):
494
708
  return EdgeDisplay(
495
709
  id=uuid4_from_hash(f"{self.workflow_id}|id|{source_node_id}|{target_node_id}"),
496
710
  )
711
+
712
+ @classmethod
713
+ def infer_workflow_class(cls) -> Type[BaseWorkflow]:
714
+ original_base = get_original_base(cls)
715
+ workflow_class = get_args(original_base)[0]
716
+ if isinstance(workflow_class, TypeVar):
717
+ bounded_class = workflow_class.__bound__
718
+ if inspect.isclass(bounded_class) and issubclass(bounded_class, BaseWorkflow):
719
+ return bounded_class
720
+
721
+ if isinstance(bounded_class, ForwardRef) and bounded_class.__forward_arg__ == BaseWorkflow.__name__:
722
+ return BaseWorkflow
723
+
724
+ if issubclass(workflow_class, BaseWorkflow):
725
+ return workflow_class
726
+
727
+ raise ValueError(f"Workflow {cls.__name__} must be a subclass of {BaseWorkflow.__name__}")
728
+
729
+ @property
730
+ def _workflow(self) -> Type[WorkflowType]:
731
+ return cast(Type[WorkflowType], self.__class__.infer_workflow_class())
732
+
733
+
734
+ register_workflow_display_class(workflow_class=BaseWorkflow, workflow_display_class=BaseWorkflowDisplay)
@@ -1,33 +1,46 @@
1
- from typing import Optional, Type
1
+ import types
2
+ from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar
2
3
 
3
4
  from vellum.workflows.types.generics import WorkflowType
4
- from vellum_ee.workflows.display.types import WorkflowDisplayContext, WorkflowDisplayType
5
+ from vellum_ee.workflows.display.types import WorkflowDisplayContext
6
+ from vellum_ee.workflows.display.utils.registry import get_from_workflow_display_registry
7
+
8
+ if TYPE_CHECKING:
9
+ from vellum_ee.workflows.display.workflows import BaseWorkflowDisplay
10
+
11
+
12
+ def _get_workflow_display_class(*, workflow_class: Type[WorkflowType]) -> Type["BaseWorkflowDisplay"]:
13
+ workflow_display_class = get_from_workflow_display_registry(workflow_class)
14
+ if workflow_display_class:
15
+ return workflow_display_class
16
+
17
+ base_workflow_display_class = _get_workflow_display_class(
18
+ workflow_class=workflow_class.__bases__[0],
19
+ )
20
+
21
+ # mypy gets upset at dynamic TypeVar's, but it's technically allowed by python
22
+ _WorkflowClassType = TypeVar(f"_{workflow_class.__name__}Type", bound=workflow_class) # type: ignore[misc]
23
+ # `base_workflow_display_class` is always a Generic class, so it's safe to index into it
24
+ WorkflowDisplayBaseClass = base_workflow_display_class[_WorkflowClassType] # type: ignore[index]
25
+
26
+ WorkflowDisplayClass = types.new_class(
27
+ f"{workflow_class.__name__}Display",
28
+ bases=(WorkflowDisplayBaseClass, Generic[_WorkflowClassType]),
29
+ )
30
+
31
+ return WorkflowDisplayClass
5
32
 
6
33
 
7
34
  def get_workflow_display(
8
35
  *,
9
- base_display_class: Type[WorkflowDisplayType],
10
36
  workflow_class: Type[WorkflowType],
11
- root_workflow_class: Optional[Type[WorkflowType]] = None,
12
37
  parent_display_context: Optional[WorkflowDisplayContext] = None,
13
38
  dry_run: bool = False,
14
- ) -> WorkflowDisplayType:
15
- try:
16
- workflow_display_class = base_display_class.get_from_workflow_display_registry(workflow_class)
17
- except KeyError:
18
- try:
19
- return get_workflow_display(
20
- base_display_class=base_display_class,
21
- workflow_class=workflow_class.__bases__[0],
22
- root_workflow_class=workflow_class if root_workflow_class is None else root_workflow_class,
23
- parent_display_context=parent_display_context,
24
- dry_run=dry_run,
25
- )
26
- except IndexError:
27
- return base_display_class(workflow_class)
28
-
29
- return workflow_display_class( # type: ignore[return-value]
30
- workflow_class,
39
+ # DEPRECATED: The following arguments will be removed in 0.15.0
40
+ root_workflow_class: Optional[Type[WorkflowType]] = None,
41
+ base_display_class: Optional[Type["BaseWorkflowDisplay"]] = None,
42
+ ) -> "BaseWorkflowDisplay":
43
+ return _get_workflow_display_class(workflow_class=workflow_class)(
31
44
  parent_display_context=parent_display_context,
32
45
  dry_run=dry_run,
33
46
  )