vellum-ai 0.13.10__py3-none-any.whl → 0.13.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32):
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/workflows/errors/types.py +21 -0
  3. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +14 -1
  4. vellum/workflows/nodes/displayable/bases/tests/test_utils.py +18 -0
  5. vellum/workflows/nodes/displayable/bases/utils.py +8 -1
  6. vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +55 -0
  7. {vellum_ai-0.13.10.dist-info → vellum_ai-0.13.11.dist-info}/METADATA +1 -1
  8. {vellum_ai-0.13.10.dist-info → vellum_ai-0.13.11.dist-info}/RECORD +31 -31
  9. vellum_cli/__init__.py +9 -1
  10. vellum_cli/config.py +29 -1
  11. vellum_cli/push.py +24 -3
  12. vellum_cli/tests/conftest.py +3 -0
  13. vellum_cli/tests/test_pull.py +6 -0
  14. vellum_cli/tests/test_push.py +88 -1
  15. vellum_ee/workflows/display/nodes/base_node_display.py +110 -6
  16. vellum_ee/workflows/display/nodes/base_node_vellum_display.py +16 -1
  17. vellum_ee/workflows/display/nodes/get_node_display_class.py +6 -4
  18. vellum_ee/workflows/display/nodes/vellum/__init__.py +0 -2
  19. vellum_ee/workflows/display/nodes/vellum/error_node.py +9 -3
  20. vellum_ee/workflows/display/nodes/vellum/tests/test_error_node.py +44 -0
  21. vellum_ee/workflows/display/nodes/vellum/try_node.py +3 -4
  22. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +1 -1
  23. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +5 -2
  24. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +1 -1
  25. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_outputs_serialization.py +1 -1
  26. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_ports_serialization.py +1 -1
  27. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +15 -1
  28. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +27 -13
  29. vellum_ee/workflows/display/nodes/vellum/base_node.py +0 -121
  30. {vellum_ai-0.13.10.dist-info → vellum_ai-0.13.11.dist-info}/LICENSE +0 -0
  31. {vellum_ai-0.13.10.dist-info → vellum_ai-0.13.11.dist-info}/WHEEL +0 -0
  32. {vellum_ai-0.13.10.dist-info → vellum_ai-0.13.11.dist-info}/entry_points.txt +0 -0
@@ -3,6 +3,7 @@ import io
3
3
  import json
4
4
  import os
5
5
  import tarfile
6
+ from unittest import mock
6
7
  from uuid import uuid4
7
8
 
8
9
  from click.testing import CliRunner
@@ -194,7 +195,7 @@ def test_push__dry_run_option_returns_report(mock_module, vellum_client):
194
195
  from typing import Dict
195
196
  from vellum.workflows import BaseWorkflow
196
197
  from vellum.workflows.nodes import BaseNode
197
- from vellum_ee.workflows.display.nodes import BaseNodeDisplay
198
+ from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
198
199
 
199
200
  class NotSupportedNode(BaseNode):
200
201
  pass
@@ -318,3 +319,89 @@ Files that were different between the original project and the generated artifac
318
319
  \x1b[0m
319
320
  """
320
321
  )
322
+
323
+
324
+ def test_push__workspace_option__uses_different_api_key(mock_module, vellum_client_class):
325
+ # GIVEN a single workflow configured
326
+ temp_dir = mock_module.temp_dir
327
+ module = mock_module.module
328
+ workflow_sandbox_id = mock_module.workflow_sandbox_id
329
+ set_pyproject_toml = mock_module.set_pyproject_toml
330
+
331
+ # AND a different workspace is set in the pyproject.toml
332
+ set_pyproject_toml(
333
+ {
334
+ "workflows": [
335
+ {
336
+ "module": module,
337
+ "workflow_sandbox_id": workflow_sandbox_id,
338
+ }
339
+ ],
340
+ "workspaces": [
341
+ {
342
+ "name": "my_other_workspace",
343
+ "api_key": "MY_OTHER_VELLUM_API_KEY",
344
+ }
345
+ ],
346
+ }
347
+ )
348
+
349
+ # AND the .env file has the other api key stored
350
+ with open(os.path.join(temp_dir, ".env"), "w") as f:
351
+ f.write(
352
+ """\
353
+ VELLUM_API_KEY=abcdef123456
354
+ MY_OTHER_VELLUM_API_KEY=aaabbbcccddd
355
+ """
356
+ )
357
+
358
+ # AND a workflow exists in the module successfully
359
+ base_dir = os.path.join(temp_dir, *module.split("."))
360
+ os.makedirs(base_dir, exist_ok=True)
361
+ workflow_py_file_content = """\
362
+ from vellum.workflows import BaseWorkflow
363
+
364
+ class ExampleWorkflow(BaseWorkflow):
365
+ pass
366
+ """
367
+ with open(os.path.join(temp_dir, *module.split("."), "workflow.py"), "w") as f:
368
+ f.write(workflow_py_file_content)
369
+
370
+ # AND the push API call returns a new workflow sandbox id
371
+ new_workflow_sandbox_id = str(uuid4())
372
+ vellum_client_class.return_value.workflows.push.return_value = WorkflowPushResponse(
373
+ workflow_sandbox_id=new_workflow_sandbox_id,
374
+ )
375
+
376
+ # WHEN calling `vellum push` on strict mode
377
+ runner = CliRunner()
378
+ result = runner.invoke(cli_main, ["push", module, "--workspace", "my_other_workspace"])
379
+
380
+ # THEN it should succeed
381
+ assert result.exit_code == 0, result.output
382
+
383
+ # AND we should have called the push API once
384
+ vellum_client_class.return_value.workflows.push.assert_called_once()
385
+
386
+ # AND the workflow sandbox id arg passed in should be `None`
387
+ call_args = vellum_client_class.return_value.workflows.push.call_args.kwargs
388
+ assert call_args["workflow_sandbox_id"] is None
389
+
390
+ # AND with the correct api key
391
+ vellum_client_class.assert_called_once_with(
392
+ api_key="aaabbbcccddd",
393
+ environment=mock.ANY,
394
+ )
395
+
396
+ # AND the vellum lock file should have been updated with the correct workspace
397
+ with open(os.path.join(temp_dir, "vellum.lock.json")) as f:
398
+ lock_file_content = json.load(f)
399
+ assert lock_file_content["workflows"][1] == {
400
+ "module": module,
401
+ "workflow_sandbox_id": new_workflow_sandbox_id,
402
+ "workspace": "my_other_workspace",
403
+ "container_image_name": None,
404
+ "container_image_tag": None,
405
+ "deployments": [],
406
+ "ignore": None,
407
+ }
@@ -16,6 +16,8 @@ from typing import (
16
16
  get_origin,
17
17
  )
18
18
 
19
+ from vellum.workflows import BaseWorkflow
20
+ from vellum.workflows.constants import UNDEF
19
21
  from vellum.workflows.descriptors.base import BaseDescriptor
20
22
  from vellum.workflows.expressions.between import BetweenExpression
21
23
  from vellum.workflows.expressions.is_nil import IsNilExpression
@@ -26,19 +28,22 @@ from vellum.workflows.expressions.is_null import IsNullExpression
26
28
  from vellum.workflows.expressions.is_undefined import IsUndefinedExpression
27
29
  from vellum.workflows.expressions.not_between import NotBetweenExpression
28
30
  from vellum.workflows.nodes.bases.base import BaseNode
31
+ from vellum.workflows.nodes.utils import get_wrapped_node
29
32
  from vellum.workflows.ports import Port
30
33
  from vellum.workflows.references import OutputReference
31
34
  from vellum.workflows.references.execution_count import ExecutionCountReference
32
35
  from vellum.workflows.references.vellum_secret import VellumSecretReference
33
36
  from vellum.workflows.references.workflow_input import WorkflowInputReference
34
- from vellum.workflows.types.core import JsonObject
37
+ from vellum.workflows.types.core import JsonArray, JsonObject
35
38
  from vellum.workflows.types.generics import NodeType
36
39
  from vellum.workflows.types.utils import get_original_base
37
40
  from vellum.workflows.utils.names import pascal_to_title_case
38
41
  from vellum.workflows.utils.uuids import uuid4_from_hash
42
+ from vellum.workflows.utils.vellum_variables import primitive_type_to_vellum_variable_type
43
+ from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
39
44
  from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay, PortDisplay, PortDisplayOverrides
40
45
  from vellum_ee.workflows.display.utils.vellum import convert_descriptor_to_operator, primitive_to_vellum_value
41
- from vellum_ee.workflows.display.vellum import CodeResourceDefinition
46
+ from vellum_ee.workflows.display.vellum import CodeResourceDefinition, GenericNodeDisplayData
42
47
 
43
48
  if TYPE_CHECKING:
44
49
  from vellum_ee.workflows.display.types import WorkflowDisplayContext
@@ -65,7 +70,97 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
65
70
  _node_display_registry: Dict[Type[NodeType], Type["BaseNodeDisplay"]] = {}
66
71
 
67
72
  def serialize(self, display_context: "WorkflowDisplayContext", **kwargs: Any) -> JsonObject:
68
- raise NotImplementedError(f"Serialization for nodes of type {self._node.__name__} is not supported.")
73
+ node = self._node
74
+ node_id = self.node_id
75
+
76
+ attributes: JsonArray = []
77
+ for attribute in node:
78
+ if inspect.isclass(attribute.instance) and issubclass(attribute.instance, BaseWorkflow):
79
+ # We don't need to serialize generic node attributes containing a subworkflow
80
+ continue
81
+
82
+ id = str(uuid4_from_hash(f"{node_id}|{attribute.name}"))
83
+ attributes.append(
84
+ {
85
+ "id": id,
86
+ "name": attribute.name,
87
+ "value": self.serialize_value(display_context, cast(BaseDescriptor, attribute.instance)),
88
+ }
89
+ )
90
+
91
+ adornments = kwargs.get("adornments", None)
92
+ wrapped_node = get_wrapped_node(node)
93
+ if wrapped_node is not None:
94
+ display_class = get_node_display_class(BaseNodeDisplay, wrapped_node)
95
+
96
+ adornment: JsonObject = {
97
+ "id": str(node_id),
98
+ "label": node.__qualname__,
99
+ "base": self.get_base().dict(),
100
+ "attributes": attributes,
101
+ }
102
+
103
+ existing_adornments = adornments if adornments is not None else []
104
+ return display_class().serialize(display_context, adornments=existing_adornments + [adornment])
105
+
106
+ ports: JsonArray = []
107
+ for port in node.Ports:
108
+ id = str(self.get_node_port_display(port).id)
109
+
110
+ if port._condition_type:
111
+ ports.append(
112
+ {
113
+ "id": id,
114
+ "name": port.name,
115
+ "type": port._condition_type.value,
116
+ "expression": (
117
+ self.serialize_condition(display_context, port._condition) if port._condition else None
118
+ ),
119
+ }
120
+ )
121
+ else:
122
+ ports.append(
123
+ {
124
+ "id": id,
125
+ "name": port.name,
126
+ "type": "DEFAULT",
127
+ }
128
+ )
129
+
130
+ outputs: JsonArray = []
131
+ for output in node.Outputs:
132
+ type = primitive_type_to_vellum_variable_type(output)
133
+ value = (
134
+ self.serialize_value(display_context, output.instance)
135
+ if output.instance is not None and output.instance != UNDEF
136
+ else None
137
+ )
138
+
139
+ outputs.append(
140
+ {
141
+ "id": str(uuid4_from_hash(f"{node_id}|{output.name}")),
142
+ "name": output.name,
143
+ "type": type,
144
+ "value": value,
145
+ }
146
+ )
147
+
148
+ return {
149
+ "id": str(node_id),
150
+ "label": node.__qualname__,
151
+ "type": "GENERIC",
152
+ "display_data": self._get_generic_node_display_data().dict(),
153
+ "base": self.get_base().dict(),
154
+ "definition": self.get_definition().dict(),
155
+ "trigger": {
156
+ "id": str(self.get_trigger_id()),
157
+ "merge_behavior": node.Trigger.merge_behavior.value,
158
+ },
159
+ "ports": ports,
160
+ "adornments": adornments,
161
+ "attributes": attributes,
162
+ "outputs": outputs,
163
+ }
69
164
 
70
165
  def get_base(self) -> CodeResourceDefinition:
71
166
  node = self._node
@@ -89,9 +184,6 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
89
184
  )
90
185
  return node_definition
91
186
 
92
- def get_trigger_id(self) -> UUID:
93
- return uuid4_from_hash(f"{self.node_id}|trigger")
94
-
95
187
  def get_node_output_display(self, output: OutputReference) -> Tuple[Type[BaseNode], NodeOutputDisplay]:
96
188
  explicit_display = self.output_display.get(output)
97
189
  if explicit_display:
@@ -110,6 +202,9 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
110
202
 
111
203
  return PortDisplay(id=port_id, node_id=self.node_id)
112
204
 
205
+ def get_trigger_id(self) -> UUID:
206
+ return uuid4_from_hash(f"{self.node_id}|trigger")
207
+
113
208
  @classmethod
114
209
  def get_from_node_display_registry(cls, node_class: Type[NodeType]) -> Type["BaseNodeDisplay"]:
115
210
  return cls._node_display_registry[node_class]
@@ -187,10 +282,19 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
187
282
 
188
283
  def __init_subclass__(cls, **kwargs: Any) -> None:
189
284
  super().__init_subclass__(**kwargs)
285
+ if not cls._node_display_registry:
286
+ cls._node_display_registry[BaseNode] = BaseNodeDisplay
190
287
 
191
288
  node_class = cls.infer_node_class()
289
+ if node_class is BaseNode:
290
+ return
291
+
192
292
  cls._node_display_registry[node_class] = cls
193
293
 
294
+ def _get_generic_node_display_data(self) -> GenericNodeDisplayData:
295
+ explicit_value = self._get_explicit_node_display_attr("display_data", GenericNodeDisplayData)
296
+ return explicit_value if explicit_value else GenericNodeDisplayData()
297
+
194
298
  def serialize_condition(self, display_context: "WorkflowDisplayContext", condition: BaseDescriptor) -> JsonObject:
195
299
  if isinstance(
196
300
  condition,
@@ -7,7 +7,7 @@ from vellum.workflows.types.generics import NodeType
7
7
  from vellum.workflows.utils.uuids import uuid4_from_hash
8
8
  from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
9
9
  from vellum_ee.workflows.display.nodes.types import PortDisplay
10
- from vellum_ee.workflows.display.vellum import NodeDisplayData
10
+ from vellum_ee.workflows.display.vellum import NodeDisplayComment, NodeDisplayData
11
11
 
12
12
 
13
13
  class BaseNodeVellumDisplay(BaseNodeDisplay[NodeType]):
@@ -26,6 +26,21 @@ class BaseNodeVellumDisplay(BaseNodeDisplay[NodeType]):
26
26
 
27
27
  def get_display_data(self) -> NodeDisplayData:
28
28
  explicit_value = self._get_explicit_node_display_attr("display_data", NodeDisplayData)
29
+ docstring = self._node.__doc__
30
+
31
+ if explicit_value and explicit_value.comment and docstring:
32
+ comment = (
33
+ NodeDisplayComment(value=docstring, expanded=explicit_value.comment.expanded)
34
+ if explicit_value.comment.expanded
35
+ else NodeDisplayComment(value=docstring)
36
+ )
37
+ return NodeDisplayData(
38
+ position=explicit_value.position,
39
+ width=explicit_value.width,
40
+ height=explicit_value.height,
41
+ comment=comment,
42
+ )
43
+
29
44
  return explicit_value if explicit_value else NodeDisplayData()
30
45
 
31
46
  def get_target_handle_id(self) -> UUID:
@@ -1,13 +1,15 @@
1
1
  import types
2
- from typing import Optional, Type
2
+ from typing import TYPE_CHECKING, Optional, Type
3
3
 
4
4
  from vellum.workflows.types.generics import NodeType
5
- from vellum_ee.workflows.display.types import NodeDisplayType
5
+
6
+ if TYPE_CHECKING:
7
+ from vellum_ee.workflows.display.types import NodeDisplayType
6
8
 
7
9
 
8
10
  def get_node_display_class(
9
- base_class: Type[NodeDisplayType], node_class: Type[NodeType], root_node_class: Optional[Type[NodeType]] = None
10
- ) -> Type[NodeDisplayType]:
11
+ base_class: Type["NodeDisplayType"], node_class: Type[NodeType], root_node_class: Optional[Type[NodeType]] = None
12
+ ) -> Type["NodeDisplayType"]:
11
13
  try:
12
14
  node_display_class = base_class.get_from_node_display_registry(node_class)
13
15
  except KeyError:
@@ -1,5 +1,4 @@
1
1
  from .api_node import BaseAPINodeDisplay
2
- from .base_node import BaseNodeDisplay
3
2
  from .code_execution_node import BaseCodeExecutionNodeDisplay
4
3
  from .conditional_node import BaseConditionalNodeDisplay
5
4
  from .error_node import BaseErrorNodeDisplay
@@ -28,7 +27,6 @@ __all__ = [
28
27
  "BaseInlineSubworkflowNodeDisplay",
29
28
  "BaseMapNodeDisplay",
30
29
  "BaseMergeNodeDisplay",
31
- "BaseNodeDisplay",
32
30
  "BaseNoteNodeDisplay",
33
31
  "BasePromptDeploymentNodeDisplay",
34
32
  "BaseSearchNodeDisplay",
@@ -1,9 +1,10 @@
1
1
  from uuid import UUID
2
- from typing import Any, ClassVar, Dict, Generic, Optional, TypeVar
2
+ from typing import ClassVar, Generic, Optional, TypeVar
3
3
 
4
4
  from vellum.workflows.nodes import ErrorNode
5
5
  from vellum.workflows.types.core import JsonObject
6
6
  from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeVellumDisplay
7
+ from vellum_ee.workflows.display.nodes.utils import raise_if_descriptor
7
8
  from vellum_ee.workflows.display.nodes.vellum.utils import create_node_input
8
9
  from vellum_ee.workflows.display.types import WorkflowDisplayContext
9
10
 
@@ -12,7 +13,7 @@ _ErrorNodeType = TypeVar("_ErrorNodeType", bound=ErrorNode)
12
13
 
13
14
  class BaseErrorNodeDisplay(BaseNodeVellumDisplay[_ErrorNodeType], Generic[_ErrorNodeType]):
14
15
  error_output_id: ClassVar[Optional[UUID]] = None
15
- error_inputs_by_name: ClassVar[Dict[str, Any]] = {}
16
+
16
17
  name: ClassVar[str] = "error-node"
17
18
 
18
19
  def serialize(
@@ -21,6 +22,11 @@ class BaseErrorNodeDisplay(BaseNodeVellumDisplay[_ErrorNodeType], Generic[_Error
21
22
  node_id = self.node_id
22
23
  error_source_input_id = self.node_input_ids_by_name.get("error_source_input_id")
23
24
 
25
+ error_attribute = raise_if_descriptor(self._node.error)
26
+ input_values_by_name = {
27
+ "error_source_input_id": error_attribute,
28
+ }
29
+
24
30
  node_inputs = [
25
31
  create_node_input(
26
32
  node_id=node_id,
@@ -29,7 +35,7 @@ class BaseErrorNodeDisplay(BaseNodeVellumDisplay[_ErrorNodeType], Generic[_Error
29
35
  display_context=display_context,
30
36
  input_id=self.node_input_ids_by_name.get(variable_name),
31
37
  )
32
- for variable_name, variable_value in self.error_inputs_by_name.items()
38
+ for variable_name, variable_value in input_values_by_name.items()
33
39
  ]
34
40
 
35
41
  return {
@@ -0,0 +1,44 @@
1
+ from typing import Any, Dict, cast
2
+
3
+ from vellum.client.types.vellum_error import VellumError
4
+ from vellum.workflows import BaseWorkflow
5
+ from vellum.workflows.nodes.core.error_node.node import ErrorNode
6
+ from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
7
+ from vellum_ee.workflows.display.workflows.vellum_workflow_display import VellumWorkflowDisplay
8
+
9
+
10
+ def test_error_node_display__serialize_with_vellum_error() -> None:
11
+ # GIVEN an Error Node with a VellumError
12
+ class MyNode(ErrorNode):
13
+ error = VellumError(
14
+ message="A bad thing happened",
15
+ code="USER_DEFINED_ERROR",
16
+ )
17
+
18
+ # AND a workflow referencing the two node
19
+ class MyWorkflow(BaseWorkflow):
20
+ graph = MyNode
21
+
22
+ # WHEN we serialize the workflow
23
+ workflow_display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=MyWorkflow)
24
+ serialized_workflow = cast(Dict[str, Any], workflow_display.serialize())
25
+
26
+ # THEN the correct inputs should be serialized on the node
27
+ serialized_node = next(
28
+ node for node in serialized_workflow["workflow_raw_data"]["nodes"] if node["id"] == str(MyNode.__id__)
29
+ )
30
+ assert serialized_node["inputs"][0]["value"] == {
31
+ "combinator": "OR",
32
+ "rules": [
33
+ {
34
+ "data": {
35
+ "type": "ERROR",
36
+ "value": {
37
+ "message": "A bad thing happened",
38
+ "code": "USER_DEFINED_ERROR",
39
+ },
40
+ },
41
+ "type": "CONSTANT_VALUE",
42
+ }
43
+ ],
44
+ }
@@ -13,7 +13,6 @@ from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeV
13
13
  from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
14
14
  from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay
15
15
  from vellum_ee.workflows.display.nodes.utils import raise_if_descriptor
16
- from vellum_ee.workflows.display.nodes.vellum.base_node import BaseNodeDisplay as GenericBaseNodeDisplay
17
16
  from vellum_ee.workflows.display.types import WorkflowDisplayContext
18
17
 
19
18
  _TryNodeType = TypeVar("_TryNodeType", bound=TryNode)
@@ -36,14 +35,14 @@ class BaseTryNodeDisplay(BaseNodeVellumDisplay[_TryNodeType], Generic[_TryNodeTy
36
35
  inner_node = subworkflow.graph
37
36
  elif inner_node.__bases__[0] is BaseNode:
38
37
  # If the wrapped node is a generic node, we let generic node do adornment handling
39
- class TryBaseNodeDisplay(GenericBaseNodeDisplay[node]): # type: ignore[valid-type]
38
+ class TryBaseNodeDisplay(BaseNodeDisplay[node]): # type: ignore[valid-type]
40
39
  pass
41
40
 
42
41
  return TryBaseNodeDisplay().serialize(display_context)
43
42
 
44
43
  # We need the node display class of the underlying node because
45
44
  # it contains the logic for serializing the node and potential display overrides
46
- node_display_class = get_node_display_class(BaseNodeVellumDisplay, inner_node)
45
+ node_display_class = get_node_display_class(BaseNodeDisplay, inner_node)
47
46
  node_display = node_display_class()
48
47
 
49
48
  serialized_node = node_display.serialize(
@@ -70,7 +69,7 @@ class BaseTryNodeDisplay(BaseNodeVellumDisplay[_TryNodeType], Generic[_TryNodeTy
70
69
  if not inner_node:
71
70
  return super().get_node_output_display(output)
72
71
 
73
- node_display_class = get_node_display_class(BaseNodeVellumDisplay, inner_node)
72
+ node_display_class = get_node_display_class(BaseNodeDisplay, inner_node)
74
73
  node_display = node_display_class()
75
74
  if output.name == "error":
76
75
  return inner_node, NodeOutputDisplay(
@@ -8,9 +8,9 @@ from vellum.workflows.references.workflow_input import WorkflowInputReference
8
8
  from vellum.workflows.types.core import JsonObject
9
9
  from vellum.workflows.types.generics import NodeType
10
10
  from vellum_ee.workflows.display.base import WorkflowInputsDisplayType
11
+ from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
11
12
  from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
12
13
  from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay
13
- from vellum_ee.workflows.display.nodes.vellum.base_node import BaseNodeDisplay
14
14
  from vellum_ee.workflows.display.types import NodeDisplayType, WorkflowDisplayContext
15
15
  from vellum_ee.workflows.display.vellum import NodeDisplayData, WorkflowMetaVellumDisplay
16
16
  from vellum_ee.workflows.display.workflows.vellum_workflow_display import VellumWorkflowDisplay
@@ -8,8 +8,8 @@ from vellum.workflows.nodes.core.retry_node.node import RetryNode
8
8
  from vellum.workflows.nodes.core.try_node.node import TryNode
9
9
  from vellum.workflows.outputs.base import BaseOutputs
10
10
  from vellum_ee.workflows.display.base import WorkflowInputsDisplay
11
+ from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
11
12
  from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeVellumDisplay
12
- from vellum_ee.workflows.display.nodes.vellum.base_node import BaseNodeDisplay
13
13
  from vellum_ee.workflows.display.nodes.vellum.try_node import BaseTryNodeDisplay
14
14
 
15
15
 
@@ -159,7 +159,10 @@ def test_serialize_node__try(serialize_node):
159
159
  {
160
160
  "id": "3344083c-a32c-4a32-920b-0fb5093448fa",
161
161
  "label": "TryNode",
162
- "base": {"name": "TryNode", "module": ["vellum", "workflows", "nodes", "core", "try_node", "node"]},
162
+ "base": {
163
+ "name": "TryNode",
164
+ "module": ["vellum", "workflows", "nodes", "core", "try_node", "node"],
165
+ },
163
166
  "attributes": [
164
167
  {
165
168
  "id": "ab2fbab0-e2a0-419b-b1ef-ce11ecf11e90",
@@ -6,8 +6,8 @@ from vellum.workflows.inputs.base import BaseInputs
6
6
  from vellum.workflows.nodes.bases.base import BaseNode
7
7
  from vellum.workflows.references.vellum_secret import VellumSecretReference
8
8
  from vellum_ee.workflows.display.base import WorkflowInputsDisplay
9
+ from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
9
10
  from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay
10
- from vellum_ee.workflows.display.nodes.vellum.base_node import BaseNodeDisplay
11
11
 
12
12
 
13
13
  class Inputs(BaseInputs):
@@ -5,8 +5,8 @@ from deepdiff import DeepDiff
5
5
  from vellum.workflows.inputs.base import BaseInputs
6
6
  from vellum.workflows.nodes.bases.base import BaseNode
7
7
  from vellum_ee.workflows.display.base import WorkflowInputsDisplay
8
+ from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
8
9
  from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay
9
- from vellum_ee.workflows.display.nodes.vellum.base_node import BaseNodeDisplay
10
10
 
11
11
 
12
12
  class Inputs(BaseInputs):
@@ -7,8 +7,8 @@ from vellum.workflows.nodes.bases.base import BaseNode
7
7
  from vellum.workflows.ports.port import Port
8
8
  from vellum.workflows.references.vellum_secret import VellumSecretReference
9
9
  from vellum_ee.workflows.display.base import WorkflowInputsDisplay
10
+ from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
10
11
  from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay
11
- from vellum_ee.workflows.display.nodes.vellum.base_node import BaseNodeDisplay
12
12
 
13
13
 
14
14
  class Inputs(BaseInputs):
@@ -88,7 +88,21 @@ def test_serialize_workflow():
88
88
  {
89
89
  "id": "5cf9c5e3-0eae-4daf-8d73-8b9536258eb9",
90
90
  "type": "ERROR",
91
- "inputs": [],
91
+ "inputs": [
92
+ {
93
+ "id": "690d825f-6ffd-493e-8141-c86d384e6150",
94
+ "key": "error_source_input_id",
95
+ "value": {
96
+ "rules": [
97
+ {
98
+ "type": "CONSTANT_VALUE",
99
+ "data": {"type": "STRING", "value": "Input threshold was too low"},
100
+ }
101
+ ],
102
+ "combinator": "OR",
103
+ },
104
+ }
105
+ ],
92
106
  "data": {
93
107
  "name": "error-node",
94
108
  "label": "Fail Node",
@@ -14,6 +14,7 @@ from vellum.workflows.references.output import OutputReference
14
14
  from vellum.workflows.types.core import JsonArray, JsonObject
15
15
  from vellum.workflows.types.generics import WorkflowType
16
16
  from vellum.workflows.utils.uuids import uuid4_from_hash
17
+ from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
17
18
  from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeVellumDisplay
18
19
  from vellum_ee.workflows.display.nodes.types import PortDisplay
19
20
  from vellum_ee.workflows.display.nodes.vellum.utils import create_node_input
@@ -43,7 +44,7 @@ class VellumWorkflowDisplay(
43
44
  WorkflowMetaVellumDisplayOverrides,
44
45
  WorkflowInputsVellumDisplay,
45
46
  WorkflowInputsVellumDisplayOverrides,
46
- BaseNodeVellumDisplay,
47
+ BaseNodeDisplay,
47
48
  EntrypointVellumDisplay,
48
49
  EntrypointVellumDisplayOverrides,
49
50
  EdgeVellumDisplay,
@@ -52,7 +53,7 @@ class VellumWorkflowDisplay(
52
53
  WorkflowOutputVellumDisplayOverrides,
53
54
  ]
54
55
  ):
55
- node_display_base_class = BaseNodeVellumDisplay
56
+ node_display_base_class = BaseNodeDisplay
56
57
 
57
58
  def serialize(self) -> JsonObject:
58
59
  input_variables: JsonArray = []
@@ -137,7 +138,7 @@ class VellumWorkflowDisplay(
137
138
  workflow_output_display.node_input_id,
138
139
  )
139
140
 
140
- source_node_display: Optional[BaseNodeVellumDisplay]
141
+ source_node_display: Optional[BaseNodeDisplay]
141
142
  first_rule = node_input.value.rules[0]
142
143
  if first_rule.type == "NODE_OUTPUT":
143
144
  source_node_id = UUID(first_rule.data.node_id)
@@ -170,15 +171,20 @@ class VellumWorkflowDisplay(
170
171
  )
171
172
 
172
173
  if source_node_display:
174
+ if isinstance(source_node_display, BaseNodeVellumDisplay):
175
+ source_handle_id = source_node_display.get_source_handle_id(
176
+ port_displays=self.display_context.port_displays
177
+ )
178
+ else:
179
+ source_handle_id = source_node_display.get_node_port_display(
180
+ source_node_display._node.Ports.default
181
+ ).id
182
+
173
183
  synthetic_output_edges.append(
174
184
  {
175
185
  "id": str(workflow_output_display.edge_id),
176
186
  "source_node_id": str(source_node_display.node_id),
177
- "source_handle_id": str(
178
- source_node_display.get_source_handle_id(
179
- port_displays=self.display_context.port_displays
180
- )
181
- ),
187
+ "source_handle_id": str(source_handle_id),
182
188
  "target_node_id": str(workflow_output_display.node_id),
183
189
  "target_handle_id": str(workflow_output_display.target_handle_id),
184
190
  "type": "DEFAULT",
@@ -279,7 +285,7 @@ class VellumWorkflowDisplay(
279
285
  self,
280
286
  entrypoint: Type[BaseNode],
281
287
  workflow_display: WorkflowMetaVellumDisplay,
282
- node_displays: Dict[Type[BaseNode], BaseNodeVellumDisplay],
288
+ node_displays: Dict[Type[BaseNode], BaseNodeDisplay],
283
289
  overrides: Optional[EntrypointVellumDisplayOverrides] = None,
284
290
  ) -> EntrypointVellumDisplay:
285
291
  entrypoint_node_id = workflow_display.entrypoint_node_id
@@ -293,8 +299,12 @@ class VellumWorkflowDisplay(
293
299
  )
294
300
 
295
301
  entrypoint_target = get_unadorned_node(entrypoint)
296
- target_node_id = node_displays[entrypoint_target].node_id
297
- target_handle_id = node_displays[entrypoint_target].get_target_handle_id()
302
+ target_node_display = node_displays[entrypoint_target]
303
+ target_node_id = target_node_display.node_id
304
+ if isinstance(target_node_display, BaseNodeVellumDisplay):
305
+ target_handle_id = target_node_display.get_target_handle_id_by_source_node_id(entrypoint_node_id)
306
+ else:
307
+ target_handle_id = target_node_display.get_trigger_id()
298
308
 
299
309
  edge_display = self._generate_edge_display_from_source(
300
310
  entrypoint_node_id, source_handle_id, target_node_id, target_handle_id, overrides=edge_display_overrides
@@ -339,7 +349,7 @@ class VellumWorkflowDisplay(
339
349
  def _generate_edge_display(
340
350
  self,
341
351
  edge: Edge,
342
- node_displays: Dict[Type[BaseNode], BaseNodeVellumDisplay],
352
+ node_displays: Dict[Type[BaseNode], BaseNodeDisplay],
343
353
  port_displays: Dict[Port, PortDisplay],
344
354
  overrides: Optional[EdgeVellumDisplayOverrides] = None,
345
355
  ) -> EdgeVellumDisplay:
@@ -352,7 +362,11 @@ class VellumWorkflowDisplay(
352
362
 
353
363
  target_node_display = node_displays[target_node]
354
364
  target_node_id = target_node_display.node_id
355
- target_handle_id = target_node_display.get_target_handle_id_by_source_node_id(source_node_id)
365
+
366
+ if isinstance(target_node_display, BaseNodeVellumDisplay):
367
+ target_handle_id = target_node_display.get_target_handle_id_by_source_node_id(source_node_id)
368
+ else:
369
+ target_handle_id = target_node_display.get_trigger_id()
356
370
 
357
371
  return self._generate_edge_display_from_source(
358
372
  source_node_id, source_handle_id, target_node_id, target_handle_id, overrides