vellum-ai 0.14.39__py3-none-any.whl → 0.14.41__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85):
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/client/reference.md +138 -1
  3. vellum/client/resources/ad_hoc/client.py +311 -1
  4. vellum/client/resources/deployments/client.py +2 -2
  5. vellum/workflows/nodes/bases/tests/test_base_node.py +24 -0
  6. vellum/workflows/nodes/core/try_node/node.py +1 -2
  7. vellum/workflows/nodes/experimental/tool_calling_node/__init__.py +3 -0
  8. vellum/workflows/nodes/experimental/tool_calling_node/node.py +125 -0
  9. vellum/workflows/nodes/experimental/tool_calling_node/utils.py +128 -0
  10. vellum/workflows/nodes/utils.py +4 -2
  11. vellum/workflows/outputs/base.py +3 -2
  12. vellum/workflows/references/output.py +20 -0
  13. vellum/workflows/state/base.py +36 -14
  14. vellum/workflows/state/tests/test_state.py +5 -2
  15. vellum/workflows/types/stack.py +11 -0
  16. vellum/workflows/workflows/base.py +5 -0
  17. vellum/workflows/workflows/tests/test_base_workflow.py +96 -9
  18. {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.41.dist-info}/METADATA +1 -1
  19. {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.41.dist-info}/RECORD +84 -80
  20. vellum_cli/push.py +0 -2
  21. vellum_ee/workflows/display/base.py +14 -1
  22. vellum_ee/workflows/display/nodes/base_node_display.py +91 -19
  23. vellum_ee/workflows/display/nodes/get_node_display_class.py +9 -15
  24. vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +54 -0
  25. vellum_ee/workflows/display/nodes/vellum/api_node.py +2 -2
  26. vellum_ee/workflows/display/nodes/vellum/base_adornment_node.py +4 -4
  27. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +2 -2
  28. vellum_ee/workflows/display/nodes/vellum/conditional_node.py +2 -2
  29. vellum_ee/workflows/display/nodes/vellum/error_node.py +2 -2
  30. vellum_ee/workflows/display/nodes/vellum/final_output_node.py +2 -2
  31. vellum_ee/workflows/display/nodes/vellum/guardrail_node.py +2 -2
  32. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +2 -2
  33. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +2 -2
  34. vellum_ee/workflows/display/nodes/vellum/merge_node.py +2 -2
  35. vellum_ee/workflows/display/nodes/vellum/note_node.py +2 -2
  36. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +2 -4
  37. vellum_ee/workflows/display/nodes/vellum/retry_node.py +1 -2
  38. vellum_ee/workflows/display/nodes/vellum/search_node.py +2 -2
  39. vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +2 -2
  40. vellum_ee/workflows/display/nodes/vellum/templating_node.py +2 -2
  41. vellum_ee/workflows/display/nodes/vellum/tests/test_code_execution_node.py +1 -2
  42. vellum_ee/workflows/display/nodes/vellum/tests/test_error_node.py +1 -2
  43. vellum_ee/workflows/display/nodes/vellum/tests/test_note_node.py +1 -2
  44. vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_node.py +55 -3
  45. vellum_ee/workflows/display/nodes/vellum/tests/test_retry_node.py +1 -2
  46. vellum_ee/workflows/display/nodes/vellum/tests/test_templating_node.py +1 -2
  47. vellum_ee/workflows/display/nodes/vellum/tests/test_try_node.py +1 -2
  48. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +4 -4
  49. vellum_ee/workflows/display/nodes/vellum/try_node.py +1 -2
  50. vellum_ee/workflows/display/nodes/vellum/utils.py +7 -1
  51. vellum_ee/workflows/display/tests/{test_vellum_workflow_display.py → test_base_workflow_display.py} +10 -22
  52. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +4 -6
  53. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +7 -16
  54. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +2 -6
  55. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +1 -2
  56. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +3 -10
  57. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +4 -5
  58. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +1 -4
  59. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +1 -4
  60. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +2 -5
  61. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +7 -5
  62. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_prompt_node_serialization.py +1 -4
  63. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -4
  64. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +1 -2
  65. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +1 -4
  66. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -4
  67. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +7 -5
  68. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +1 -4
  69. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -4
  70. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +1 -4
  71. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +2 -5
  72. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +2 -7
  73. vellum_ee/workflows/display/types.py +5 -4
  74. vellum_ee/workflows/display/utils/exceptions.py +7 -0
  75. vellum_ee/workflows/display/utils/registry.py +37 -0
  76. vellum_ee/workflows/display/utils/vellum.py +2 -1
  77. vellum_ee/workflows/display/workflows/base_workflow_display.py +277 -47
  78. vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +34 -21
  79. vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +58 -20
  80. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +4 -257
  81. vellum_ee/workflows/tests/local_workflow/display/workflow.py +2 -2
  82. vellum_ee/workflows/display/nodes/base_node_vellum_display.py +0 -40
  83. {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.41.dist-info}/LICENSE +0 -0
  84. {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.41.dist-info}/WHEEL +0 -0
  85. {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.41.dist-info}/entry_points.txt +0 -0
@@ -2,7 +2,6 @@ import pytest
2
2
 
3
3
  from deepdiff import DeepDiff
4
4
 
5
- from vellum_ee.workflows.display.workflows import VellumWorkflowDisplay
6
5
  from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
7
6
 
8
7
  from tests.workflows.complex_final_output_node.missing_final_output_node import MissingFinalOutputNodeWorkflow
@@ -11,9 +10,7 @@ from tests.workflows.complex_final_output_node.missing_workflow_output import Mi
11
10
 
12
11
  def test_serialize_workflow__missing_final_output_node():
13
12
  # GIVEN a Workflow that is missing a Terminal Node
14
- workflow_display = get_workflow_display(
15
- base_display_class=VellumWorkflowDisplay, workflow_class=MissingFinalOutputNodeWorkflow
16
- )
13
+ workflow_display = get_workflow_display(workflow_class=MissingFinalOutputNodeWorkflow)
17
14
 
18
15
  # WHEN we serialize it
19
16
  serialized_workflow: dict = workflow_display.serialize()
@@ -179,9 +176,7 @@ def test_serialize_workflow__missing_final_output_node():
179
176
 
180
177
  def test_serialize_workflow__missing_workflow_output():
181
178
  # GIVEN a Workflow that contains a terminal node that is unreferenced by the Workflow's Outputs
182
- workflow_display = get_workflow_display(
183
- base_display_class=VellumWorkflowDisplay, workflow_class=MissingWorkflowOutputWorkflow
184
- )
179
+ workflow_display = get_workflow_display(workflow_class=MissingWorkflowOutputWorkflow)
185
180
 
186
181
  # WHEN we serialize it, it should throw an error
187
182
  with pytest.raises(ValueError) as exc_info:
@@ -1,11 +1,12 @@
1
1
  from dataclasses import dataclass, field
2
- from typing import TYPE_CHECKING, Dict, Tuple, Type, TypeVar
2
+ from typing import TYPE_CHECKING, Dict, Tuple, Type
3
3
 
4
4
  from vellum.workflows.descriptors.base import BaseDescriptor
5
5
  from vellum.workflows.events.workflow import WorkflowEventDisplayContext # noqa: F401
6
6
  from vellum.workflows.nodes import BaseNode
7
7
  from vellum.workflows.ports import Port
8
8
  from vellum.workflows.references import OutputReference, StateValueReference, WorkflowInputReference
9
+ from vellum.workflows.workflows.base import BaseWorkflow
9
10
  from vellum_ee.workflows.display.base import (
10
11
  EdgeDisplay,
11
12
  EntrypointDisplay,
@@ -16,11 +17,11 @@ from vellum_ee.workflows.display.base import (
16
17
  )
17
18
  from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
18
19
  from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay, PortDisplay
20
+ from vellum_ee.workflows.display.utils.registry import get_default_workflow_display_class
19
21
 
20
22
  if TYPE_CHECKING:
21
23
  from vellum_ee.workflows.display.workflows import BaseWorkflowDisplay
22
24
 
23
- WorkflowDisplayType = TypeVar("WorkflowDisplayType", bound="BaseWorkflowDisplay")
24
25
 
25
26
  WorkflowInputsDisplays = Dict[WorkflowInputReference, WorkflowInputsDisplay]
26
27
  StateValueDisplays = Dict[StateValueReference, StateValueDisplay]
@@ -34,8 +35,8 @@ PortDisplays = Dict[Port, PortDisplay]
34
35
 
35
36
  @dataclass
36
37
  class WorkflowDisplayContext:
37
- workflow_display_class: Type["BaseWorkflowDisplay"]
38
- workflow_display: WorkflowMetaDisplay
38
+ workflow_display_class: Type["BaseWorkflowDisplay"] = field(default_factory=get_default_workflow_display_class)
39
+ workflow_display: WorkflowMetaDisplay = field(default_factory=lambda: WorkflowMetaDisplay.get_default(BaseWorkflow))
39
40
  workflow_input_displays: WorkflowInputsDisplays = field(default_factory=dict)
40
41
  global_workflow_input_displays: WorkflowInputsDisplays = field(default_factory=dict)
41
42
  state_value_displays: StateValueDisplays = field(default_factory=dict)
@@ -0,0 +1,7 @@
1
+ class UserFacingException(Exception):
2
+ def to_message(self) -> str:
3
+ return str(self)
4
+
5
+
6
+ class UnsupportedSerializationException(UserFacingException):
7
+ pass
@@ -0,0 +1,37 @@
1
+ from typing import TYPE_CHECKING, Dict, Optional, Type
2
+
3
+ from vellum.workflows.nodes import BaseNode
4
+ from vellum.workflows.workflows.base import BaseWorkflow
5
+
6
+ if TYPE_CHECKING:
7
+ from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
8
+ from vellum_ee.workflows.display.workflows.base_workflow_display import BaseWorkflowDisplay
9
+
10
+
11
+ # Used to store the mapping between workflows and their display classes
12
+ _workflow_display_registry: Dict[Type[BaseWorkflow], Type["BaseWorkflowDisplay"]] = {}
13
+
14
+ # Used to store the mapping between node types and their display classes
15
+ _node_display_registry: Dict[Type[BaseNode], Type["BaseNodeDisplay"]] = {}
16
+
17
+
18
+ def get_from_workflow_display_registry(workflow_class: Type[BaseWorkflow]) -> Optional[Type["BaseWorkflowDisplay"]]:
19
+ return _workflow_display_registry.get(workflow_class)
20
+
21
+
22
+ def register_workflow_display_class(
23
+ workflow_class: Type[BaseWorkflow], workflow_display_class: Type["BaseWorkflowDisplay"]
24
+ ) -> None:
25
+ _workflow_display_registry[workflow_class] = workflow_display_class
26
+
27
+
28
+ def get_default_workflow_display_class() -> Type["BaseWorkflowDisplay"]:
29
+ return _workflow_display_registry[BaseWorkflow]
30
+
31
+
32
+ def get_from_node_display_registry(node_class: Type[BaseNode]) -> Optional[Type["BaseNodeDisplay"]]:
33
+ return _node_display_registry.get(node_class)
34
+
35
+
36
+ def register_node_display_class(node_class: Type[BaseNode], node_display_class: Type["BaseNodeDisplay"]) -> None:
37
+ _node_display_registry[node_class] = node_display_class
@@ -41,6 +41,7 @@ from vellum.workflows.references.node import NodeReference
41
41
  from vellum.workflows.references.vellum_secret import VellumSecretReference
42
42
  from vellum.workflows.utils.vellum_variables import primitive_type_to_vellum_variable_type
43
43
  from vellum.workflows.vellum_client import create_vellum_client
44
+ from vellum_ee.workflows.display.utils.exceptions import UnsupportedSerializationException
44
45
  from vellum_ee.workflows.display.utils.expressions import get_child_descriptor
45
46
 
46
47
  if TYPE_CHECKING:
@@ -164,7 +165,7 @@ def create_node_input_value_pointer_rule(
164
165
  vellum_value = primitive_to_vellum_value(value)
165
166
  return ConstantValuePointer(type="CONSTANT_VALUE", data=vellum_value)
166
167
 
167
- raise ValueError(f"Unsupported descriptor type: {value.__class__.__name__}")
168
+ raise UnsupportedSerializationException(f"Unsupported descriptor type: {value.__class__.__name__}")
168
169
 
169
170
 
170
171
  def convert_descriptor_to_operator(descriptor: BaseDescriptor) -> LogicalOperator:
@@ -1,21 +1,25 @@
1
- from abc import abstractmethod
2
1
  from copy import copy
3
2
  from functools import cached_property
4
3
  import importlib
4
+ import inspect
5
5
  import logging
6
6
  from uuid import UUID
7
- from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, Type, Union, get_args
7
+ from typing import Any, Dict, ForwardRef, Generic, Iterator, List, Optional, Tuple, Type, TypeVar, Union, cast, get_args
8
8
 
9
9
  from vellum.workflows import BaseWorkflow
10
+ from vellum.workflows.constants import undefined
10
11
  from vellum.workflows.descriptors.base import BaseDescriptor
11
12
  from vellum.workflows.edges import Edge
12
13
  from vellum.workflows.events.workflow import NodeEventDisplayContext, WorkflowEventDisplayContext
13
14
  from vellum.workflows.nodes.bases import BaseNode
14
- from vellum.workflows.nodes.utils import get_unadorned_node, get_wrapped_node
15
+ from vellum.workflows.nodes.displayable.bases.utils import primitive_to_vellum_value
16
+ from vellum.workflows.nodes.displayable.final_output_node.node import FinalOutputNode
17
+ from vellum.workflows.nodes.utils import get_unadorned_node, get_unadorned_port, get_wrapped_node
15
18
  from vellum.workflows.ports import Port
16
19
  from vellum.workflows.references import OutputReference, WorkflowInputReference
17
- from vellum.workflows.types.core import JsonObject
20
+ from vellum.workflows.types.core import JsonArray, JsonObject
18
21
  from vellum.workflows.types.generics import WorkflowType
22
+ from vellum.workflows.types.utils import get_original_base
19
23
  from vellum.workflows.utils.uuids import uuid4_from_hash
20
24
  from vellum_ee.workflows.display.base import (
21
25
  EdgeDisplay,
@@ -27,10 +31,10 @@ from vellum_ee.workflows.display.base import (
27
31
  )
28
32
  from vellum_ee.workflows.display.editor.types import NodeDisplayData
29
33
  from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
30
- from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeVellumDisplay
31
34
  from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
32
35
  from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay, PortDisplay
33
36
  from vellum_ee.workflows.display.nodes.utils import raise_if_descriptor
37
+ from vellum_ee.workflows.display.nodes.vellum.utils import create_node_input
34
38
  from vellum_ee.workflows.display.types import (
35
39
  EdgeDisplays,
36
40
  EntrypointDisplays,
@@ -42,6 +46,8 @@ from vellum_ee.workflows.display.types import (
42
46
  WorkflowInputsDisplays,
43
47
  WorkflowOutputDisplays,
44
48
  )
49
+ from vellum_ee.workflows.display.utils.registry import register_workflow_display_class
50
+ from vellum_ee.workflows.display.utils.vellum import infer_vellum_variable_type
45
51
  from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
46
52
 
47
53
  logger = logging.getLogger(__name__)
@@ -69,45 +75,257 @@ class BaseWorkflowDisplay(Generic[WorkflowType]):
69
75
  # Used to explicitly specify display data for a workflow's ports.
70
76
  port_displays: PortDisplays = {}
71
77
 
72
- # Used to store the mapping between workflows and their display classes
73
- _workflow_display_registry: Dict[Type[WorkflowType], Type["BaseWorkflowDisplay"]] = {}
74
-
75
78
  _errors: List[Exception]
76
79
 
77
80
  _dry_run: bool
78
81
 
79
82
  def __init__(
80
83
  self,
81
- workflow: Type[WorkflowType],
82
84
  *,
83
85
  parent_display_context: Optional[WorkflowDisplayContext] = None,
84
86
  dry_run: bool = False,
85
87
  ):
86
- self._workflow = workflow
87
88
  self._parent_display_context = parent_display_context
88
89
  self._errors: List[Exception] = []
89
90
  self._dry_run = dry_run
90
91
 
91
- @abstractmethod
92
92
  def serialize(self) -> JsonObject:
93
- pass
93
+ input_variables: JsonArray = []
94
+ for workflow_input_reference, workflow_input_display in self.display_context.workflow_input_displays.items():
95
+ default = (
96
+ primitive_to_vellum_value(workflow_input_reference.instance)
97
+ if workflow_input_reference.instance
98
+ else None
99
+ )
100
+ input_variables.append(
101
+ {
102
+ "id": str(workflow_input_display.id),
103
+ "key": workflow_input_display.name or workflow_input_reference.name,
104
+ "type": infer_vellum_variable_type(workflow_input_reference),
105
+ "default": default.dict() if default else None,
106
+ "required": workflow_input_reference.instance is undefined,
107
+ "extensions": {"color": workflow_input_display.color},
108
+ }
109
+ )
94
110
 
95
- @classmethod
96
- def get_from_workflow_display_registry(cls, workflow_class: Type[WorkflowType]) -> Type["BaseWorkflowDisplay"]:
97
- try:
98
- return cls._workflow_display_registry[workflow_class]
99
- except KeyError:
100
- return cls._workflow_display_registry[WorkflowType] # type: ignore [misc]
111
+ state_variables: JsonArray = []
112
+ for state_value_reference, state_value_display in self.display_context.state_value_displays.items():
113
+ default = (
114
+ primitive_to_vellum_value(state_value_reference.instance) if state_value_reference.instance else None
115
+ )
116
+ state_variables.append(
117
+ {
118
+ "id": str(state_value_display.id),
119
+ "key": state_value_display.name or state_value_reference.name,
120
+ "type": infer_vellum_variable_type(state_value_reference),
121
+ "default": default.dict() if default else None,
122
+ "required": state_value_reference.instance is undefined,
123
+ "extensions": {"color": state_value_display.color},
124
+ }
125
+ )
126
+
127
+ nodes: JsonArray = []
128
+ edges: JsonArray = []
129
+
130
+ # Add a single synthetic node for the workflow entrypoint
131
+ entrypoint_node_id = self.display_context.workflow_display.entrypoint_node_id
132
+ entrypoint_node_source_handle_id = self.display_context.workflow_display.entrypoint_node_source_handle_id
133
+ nodes.append(
134
+ {
135
+ "id": str(entrypoint_node_id),
136
+ "type": "ENTRYPOINT",
137
+ "inputs": [],
138
+ "data": {
139
+ "label": "Entrypoint Node",
140
+ "source_handle_id": str(entrypoint_node_source_handle_id),
141
+ },
142
+ "display_data": self.display_context.workflow_display.entrypoint_node_display.dict(),
143
+ "base": None,
144
+ "definition": None,
145
+ },
146
+ )
147
+
148
+ # Add all the nodes in the workflow
149
+ for node in self._workflow.get_nodes():
150
+ node_display = self.display_context.node_displays[node]
151
+
152
+ try:
153
+ serialized_node = node_display.serialize(self.display_context)
154
+ except NotImplementedError as e:
155
+ self.add_error(e)
156
+ continue
157
+
158
+ nodes.append(serialized_node)
159
+
160
+ # Add all unused nodes in the workflow
161
+ for node in self._workflow.get_unused_nodes():
162
+ node_display = self.display_context.node_displays[node]
163
+
164
+ try:
165
+ serialized_node = node_display.serialize(self.display_context)
166
+ except NotImplementedError as e:
167
+ self.add_error(e)
168
+ continue
169
+
170
+ nodes.append(serialized_node)
171
+
172
+ synthetic_output_edges: JsonArray = []
173
+ output_variables: JsonArray = []
174
+ final_output_nodes = [
175
+ node for node in self.display_context.node_displays.keys() if issubclass(node, FinalOutputNode)
176
+ ]
177
+ final_output_node_outputs = {node.Outputs.value for node in final_output_nodes}
178
+ unreferenced_final_output_node_outputs = final_output_node_outputs.copy()
179
+ final_output_node_base: JsonObject = {
180
+ "name": FinalOutputNode.__name__,
181
+ "module": cast(JsonArray, FinalOutputNode.__module__.split(".")),
182
+ }
183
+
184
+ # Add a synthetic Terminal Node and track the Workflow's output variables for each Workflow output
185
+ for workflow_output, workflow_output_display in self.display_context.workflow_output_displays.items():
186
+ final_output_node_id = uuid4_from_hash(f"{self.workflow_id}|node_id|{workflow_output.name}")
187
+ inferred_type = infer_vellum_variable_type(workflow_output)
188
+
189
+ # Remove the terminal node output from the unreferenced set
190
+ unreferenced_final_output_node_outputs.discard(cast(OutputReference, workflow_output.instance))
191
+
192
+ if workflow_output.instance not in final_output_node_outputs:
193
+ # Create a synthetic terminal node only if there is no terminal node for this output
194
+ try:
195
+ node_input = create_node_input(
196
+ final_output_node_id,
197
+ "node_input",
198
+ # This is currently the wrapper node's output, but we want the wrapped node
199
+ workflow_output.instance,
200
+ self.display_context,
201
+ )
202
+ except ValueError as e:
203
+ raise ValueError(f"Failed to serialize output '{workflow_output.name}': {str(e)}") from e
204
+
205
+ source_node_display: Optional[BaseNodeDisplay]
206
+ first_rule = node_input.value.rules[0]
207
+ if first_rule.type == "NODE_OUTPUT":
208
+ source_node_id = UUID(first_rule.data.node_id)
209
+ try:
210
+ source_node_display = [
211
+ node_display
212
+ for node_display in self.display_context.node_displays.values()
213
+ if node_display.node_id == source_node_id
214
+ ][0]
215
+ except IndexError:
216
+ source_node_display = None
217
+
218
+ synthetic_target_handle_id = str(
219
+ uuid4_from_hash(f"{self.workflow_id}|target_handle_id|{workflow_output_display.name}")
220
+ )
221
+ synthetic_display_data = NodeDisplayData().dict()
222
+ synthetic_node_label = "Final Output"
223
+ nodes.append(
224
+ {
225
+ "id": str(final_output_node_id),
226
+ "type": "TERMINAL",
227
+ "data": {
228
+ "label": synthetic_node_label,
229
+ "name": workflow_output_display.name,
230
+ "target_handle_id": synthetic_target_handle_id,
231
+ "output_id": str(workflow_output_display.id),
232
+ "output_type": inferred_type,
233
+ "node_input_id": str(node_input.id),
234
+ },
235
+ "inputs": [node_input.dict()],
236
+ "display_data": synthetic_display_data,
237
+ "base": final_output_node_base,
238
+ "definition": None,
239
+ }
240
+ )
241
+
242
+ if source_node_display:
243
+ source_handle_id = source_node_display.get_source_handle_id(
244
+ port_displays=self.display_context.port_displays
245
+ )
246
+
247
+ synthetic_output_edges.append(
248
+ {
249
+ "id": str(uuid4_from_hash(f"{self.workflow_id}|edge_id|{workflow_output_display.name}")),
250
+ "source_node_id": str(source_node_display.node_id),
251
+ "source_handle_id": str(source_handle_id),
252
+ "target_node_id": str(final_output_node_id),
253
+ "target_handle_id": synthetic_target_handle_id,
254
+ "type": "DEFAULT",
255
+ }
256
+ )
257
+
258
+ output_variables.append(
259
+ {
260
+ "id": str(workflow_output_display.id),
261
+ "key": workflow_output_display.name,
262
+ "type": inferred_type,
263
+ }
264
+ )
265
+
266
+ # If there are terminal nodes with no workflow output reference,
267
+ # raise a serialization error
268
+ if len(unreferenced_final_output_node_outputs) > 0:
269
+ self.add_error(
270
+ ValueError("Unable to serialize terminal nodes that are not referenced by workflow outputs.")
271
+ )
272
+
273
+ # Add an edge for each edge in the workflow
274
+ for target_node, entrypoint_display in self.display_context.entrypoint_displays.items():
275
+ unadorned_target_node = get_unadorned_node(target_node)
276
+ target_node_display = self.display_context.node_displays[unadorned_target_node]
277
+ edges.append(
278
+ {
279
+ "id": str(entrypoint_display.edge_display.id),
280
+ "source_node_id": str(entrypoint_node_id),
281
+ "source_handle_id": str(entrypoint_node_source_handle_id),
282
+ "target_node_id": str(target_node_display.node_id),
283
+ "target_handle_id": str(target_node_display.get_trigger_id()),
284
+ "type": "DEFAULT",
285
+ }
286
+ )
287
+
288
+ for (source_node_port, target_node), edge_display in self.display_context.edge_displays.items():
289
+ unadorned_source_node_port = get_unadorned_port(source_node_port)
290
+ unadorned_target_node = get_unadorned_node(target_node)
291
+
292
+ source_node_port_display = self.display_context.port_displays[unadorned_source_node_port]
293
+ target_node_display = self.display_context.node_displays[unadorned_target_node]
294
+
295
+ edges.append(
296
+ {
297
+ "id": str(edge_display.id),
298
+ "source_node_id": str(source_node_port_display.node_id),
299
+ "source_handle_id": str(source_node_port_display.id),
300
+ "target_node_id": str(target_node_display.node_id),
301
+ "target_handle_id": str(
302
+ target_node_display.get_target_handle_id_by_source_node_id(source_node_port_display.node_id)
303
+ ),
304
+ "type": "DEFAULT",
305
+ }
306
+ )
307
+
308
+ edges.extend(synthetic_output_edges)
309
+
310
+ return {
311
+ "workflow_raw_data": {
312
+ "nodes": nodes,
313
+ "edges": edges,
314
+ "display_data": self.display_context.workflow_display.display_data.dict(),
315
+ "definition": {
316
+ "name": self._workflow.__name__,
317
+ "module": cast(JsonArray, self._workflow.__module__.split(".")),
318
+ },
319
+ },
320
+ "input_variables": input_variables,
321
+ "state_variables": state_variables,
322
+ "output_variables": output_variables,
323
+ }
101
324
 
102
325
  @cached_property
103
326
  def workflow_id(self) -> UUID:
104
327
  """Can be overridden as a class attribute to specify a custom workflow id."""
105
- return uuid4_from_hash(self._workflow.__qualname__)
106
-
107
- @property
108
- @abstractmethod
109
- def node_display_base_class(self) -> Type[BaseNodeDisplay]:
110
- pass
328
+ return self._workflow.__id__
111
329
 
112
330
  def add_error(self, error: Exception) -> None:
113
331
  if self._dry_run:
@@ -159,13 +377,8 @@ class BaseWorkflowDisplay(Generic[WorkflowType]):
159
377
  port_displays[port] = node_display.get_node_port_display(port)
160
378
 
161
379
  def _get_node_display(self, node: Type[BaseNode]) -> BaseNodeDisplay:
162
- node_display_class = get_node_display_class(self.node_display_base_class, node)
163
- node_display = node_display_class()
164
-
165
- if not isinstance(node_display, self.node_display_base_class):
166
- raise ValueError(f"{node.__name__} must be a subclass of {self.node_display_base_class.__name__}")
167
-
168
- return node_display
380
+ node_display_class = get_node_display_class(node)
381
+ return node_display_class()
169
382
 
170
383
  @cached_property
171
384
  def display_context(self) -> WorkflowDisplayContext:
@@ -294,14 +507,7 @@ class BaseWorkflowDisplay(Generic[WorkflowType]):
294
507
  display_data=overrides.display_data,
295
508
  )
296
509
 
297
- entrypoint_node_id = uuid4_from_hash(f"{self.workflow_id}|entrypoint_node_id")
298
- entrypoint_node_source_handle_id = uuid4_from_hash(f"{self.workflow_id}|entrypoint_node_source_handle_id")
299
-
300
- return WorkflowMetaDisplay(
301
- entrypoint_node_id=entrypoint_node_id,
302
- entrypoint_node_source_handle_id=entrypoint_node_source_handle_id,
303
- entrypoint_node_display=NodeDisplayData(),
304
- )
510
+ return WorkflowMetaDisplay.get_default(self._workflow)
305
511
 
306
512
  def _generate_workflow_input_display(
307
513
  self, workflow_input: WorkflowInputReference, overrides: Optional[WorkflowInputsDisplay] = None
@@ -368,11 +574,13 @@ class BaseWorkflowDisplay(Generic[WorkflowType]):
368
574
  super().__init_subclass__(**kwargs)
369
575
 
370
576
  workflow_class = get_args(cls.__orig_bases__[0])[0] # type: ignore [attr-defined]
371
- cls._workflow_display_registry[workflow_class] = cls
577
+ register_workflow_display_class(workflow_class=workflow_class, workflow_display_class=cls)
372
578
 
373
579
  @staticmethod
374
580
  def gather_event_display_context(
375
- module_path: str, workflow_class: Type[BaseWorkflow]
581
+ module_path: str,
582
+ # DEPRECATED: This will be removed in the 0.15.0 release
583
+ workflow_class: Optional[Type[BaseWorkflow]] = None,
376
584
  ) -> Union[WorkflowEventDisplayContext, None]:
377
585
  workflow_display_module = f"{module_path}.display.workflow"
378
586
  try:
@@ -380,11 +588,11 @@ class BaseWorkflowDisplay(Generic[WorkflowType]):
380
588
  except ModuleNotFoundError:
381
589
  return None
382
590
 
383
- workflow_display = display_class.WorkflowDisplay(workflow_class)
384
- if not isinstance(workflow_display, BaseWorkflowDisplay):
591
+ WorkflowDisplayClass = display_class.WorkflowDisplay
592
+ if not isinstance(WorkflowDisplayClass, type) or not issubclass(WorkflowDisplayClass, BaseWorkflowDisplay):
385
593
  return None
386
594
 
387
- return workflow_display.get_event_display_context()
595
+ return WorkflowDisplayClass().get_event_display_context()
388
596
 
389
597
  def get_event_display_context(self):
390
598
  display_context = self.display_context
@@ -403,9 +611,7 @@ class BaseWorkflowDisplay(Generic[WorkflowType]):
403
611
  node_event_displays = {}
404
612
  for node_id in node_displays:
405
613
  node, current_node_display = node_displays[node_id]
406
- input_display = {}
407
- if isinstance(current_node_display, BaseNodeVellumDisplay):
408
- input_display = current_node_display.node_input_ids_by_name
614
+ input_display = current_node_display.node_input_ids_by_name
409
615
  output_display = {
410
616
  output.name: current_node_display.output_display[output].id
411
617
  for output in current_node_display.output_display
@@ -494,3 +700,27 @@ class BaseWorkflowDisplay(Generic[WorkflowType]):
494
700
  return EdgeDisplay(
495
701
  id=uuid4_from_hash(f"{self.workflow_id}|id|{source_node_id}|{target_node_id}"),
496
702
  )
703
+
704
+ @classmethod
705
+ def infer_workflow_class(cls) -> Type[BaseWorkflow]:
706
+ original_base = get_original_base(cls)
707
+ workflow_class = get_args(original_base)[0]
708
+ if isinstance(workflow_class, TypeVar):
709
+ bounded_class = workflow_class.__bound__
710
+ if inspect.isclass(bounded_class) and issubclass(bounded_class, BaseWorkflow):
711
+ return bounded_class
712
+
713
+ if isinstance(bounded_class, ForwardRef) and bounded_class.__forward_arg__ == BaseWorkflow.__name__:
714
+ return BaseWorkflow
715
+
716
+ if issubclass(workflow_class, BaseWorkflow):
717
+ return workflow_class
718
+
719
+ raise ValueError(f"Workflow {cls.__name__} must be a subclass of {BaseWorkflow.__name__}")
720
+
721
+ @property
722
+ def _workflow(self) -> Type[WorkflowType]:
723
+ return cast(Type[WorkflowType], self.__class__.infer_workflow_class())
724
+
725
+
726
+ register_workflow_display_class(workflow_class=BaseWorkflow, workflow_display_class=BaseWorkflowDisplay)
@@ -1,33 +1,46 @@
1
- from typing import Optional, Type
1
+ import types
2
+ from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar
2
3
 
3
4
  from vellum.workflows.types.generics import WorkflowType
4
- from vellum_ee.workflows.display.types import WorkflowDisplayContext, WorkflowDisplayType
5
+ from vellum_ee.workflows.display.types import WorkflowDisplayContext
6
+ from vellum_ee.workflows.display.utils.registry import get_from_workflow_display_registry
7
+
8
+ if TYPE_CHECKING:
9
+ from vellum_ee.workflows.display.workflows import BaseWorkflowDisplay
10
+
11
+
12
+ def _get_workflow_display_class(*, workflow_class: Type[WorkflowType]) -> Type["BaseWorkflowDisplay"]:
13
+ workflow_display_class = get_from_workflow_display_registry(workflow_class)
14
+ if workflow_display_class:
15
+ return workflow_display_class
16
+
17
+ base_workflow_display_class = _get_workflow_display_class(
18
+ workflow_class=workflow_class.__bases__[0],
19
+ )
20
+
21
+ # mypy gets upset at dynamic TypeVar's, but it's technically allowed by python
22
+ _WorkflowClassType = TypeVar(f"_{workflow_class.__name__}Type", bound=workflow_class) # type: ignore[misc]
23
+ # `base_workflow_display_class` is always a Generic class, so it's safe to index into it
24
+ WorkflowDisplayBaseClass = base_workflow_display_class[_WorkflowClassType] # type: ignore[index]
25
+
26
+ WorkflowDisplayClass = types.new_class(
27
+ f"{workflow_class.__name__}Display",
28
+ bases=(WorkflowDisplayBaseClass, Generic[_WorkflowClassType]),
29
+ )
30
+
31
+ return WorkflowDisplayClass
5
32
 
6
33
 
7
34
  def get_workflow_display(
8
35
  *,
9
- base_display_class: Type[WorkflowDisplayType],
10
36
  workflow_class: Type[WorkflowType],
11
- root_workflow_class: Optional[Type[WorkflowType]] = None,
12
37
  parent_display_context: Optional[WorkflowDisplayContext] = None,
13
38
  dry_run: bool = False,
14
- ) -> WorkflowDisplayType:
15
- try:
16
- workflow_display_class = base_display_class.get_from_workflow_display_registry(workflow_class)
17
- except KeyError:
18
- try:
19
- return get_workflow_display(
20
- base_display_class=base_display_class,
21
- workflow_class=workflow_class.__bases__[0],
22
- root_workflow_class=workflow_class if root_workflow_class is None else root_workflow_class,
23
- parent_display_context=parent_display_context,
24
- dry_run=dry_run,
25
- )
26
- except IndexError:
27
- return base_display_class(workflow_class)
28
-
29
- return workflow_display_class( # type: ignore[return-value]
30
- workflow_class,
39
+ # DEPRECATED: The following arguments will be removed in 0.15.0
40
+ root_workflow_class: Optional[Type[WorkflowType]] = None,
41
+ base_display_class: Optional[Type["BaseWorkflowDisplay"]] = None,
42
+ ) -> "BaseWorkflowDisplay":
43
+ return _get_workflow_display_class(workflow_class=workflow_class)(
31
44
  parent_display_context=parent_display_context,
32
45
  dry_run=dry_run,
33
46
  )