vellum-ai 0.13.10__py3-none-any.whl → 0.13.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34):
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/workflows/errors/types.py +21 -0
  3. vellum/workflows/nodes/bases/base.py +2 -9
  4. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +14 -1
  5. vellum/workflows/nodes/displayable/bases/tests/test_utils.py +18 -0
  6. vellum/workflows/nodes/displayable/bases/utils.py +8 -1
  7. vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +55 -0
  8. {vellum_ai-0.13.10.dist-info → vellum_ai-0.13.12.dist-info}/METADATA +1 -1
  9. {vellum_ai-0.13.10.dist-info → vellum_ai-0.13.12.dist-info}/RECORD +33 -33
  10. vellum_cli/__init__.py +9 -1
  11. vellum_cli/config.py +29 -1
  12. vellum_cli/push.py +24 -3
  13. vellum_cli/tests/conftest.py +3 -0
  14. vellum_cli/tests/test_pull.py +6 -0
  15. vellum_cli/tests/test_push.py +88 -1
  16. vellum_ee/workflows/display/nodes/base_node_display.py +118 -6
  17. vellum_ee/workflows/display/nodes/base_node_vellum_display.py +16 -1
  18. vellum_ee/workflows/display/nodes/get_node_display_class.py +6 -4
  19. vellum_ee/workflows/display/nodes/vellum/__init__.py +0 -2
  20. vellum_ee/workflows/display/nodes/vellum/error_node.py +9 -3
  21. vellum_ee/workflows/display/nodes/vellum/tests/test_error_node.py +44 -0
  22. vellum_ee/workflows/display/nodes/vellum/try_node.py +3 -4
  23. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +1 -1
  24. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +5 -2
  25. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +187 -6
  26. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_outputs_serialization.py +1 -1
  27. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_ports_serialization.py +1 -1
  28. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +15 -1
  29. vellum_ee/workflows/display/utils/vellum.py +3 -0
  30. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +27 -13
  31. vellum_ee/workflows/display/nodes/vellum/base_node.py +0 -121
  32. {vellum_ai-0.13.10.dist-info → vellum_ai-0.13.12.dist-info}/LICENSE +0 -0
  33. {vellum_ai-0.13.10.dist-info → vellum_ai-0.13.12.dist-info}/WHEEL +0 -0
  34. {vellum_ai-0.13.10.dist-info → vellum_ai-0.13.12.dist-info}/entry_points.txt +0 -0
@@ -124,6 +124,7 @@ def test_pull__sandbox_id_with_no_config(vellum_client):
124
124
  lock_data = json.loads(f.read())
125
125
  assert lock_data == {
126
126
  "version": "1.0",
127
+ "workspaces": [],
127
128
  "workflows": [
128
129
  {
129
130
  "module": "workflow_87654321",
@@ -132,6 +133,7 @@ def test_pull__sandbox_id_with_no_config(vellum_client):
132
133
  "deployments": [],
133
134
  "container_image_tag": None,
134
135
  "container_image_name": None,
136
+ "workspace": "default",
135
137
  }
136
138
  ],
137
139
  }
@@ -208,8 +210,10 @@ def test_pull__workflow_deployment_with_no_config(vellum_client):
208
210
  "deployments": [],
209
211
  "container_image_tag": None,
210
212
  "container_image_name": None,
213
+ "workspace": "default",
211
214
  }
212
215
  ],
216
+ "workspaces": [],
213
217
  }
214
218
 
215
219
 
@@ -449,6 +453,7 @@ def test_pull__sandbox_id_with_other_workflow_deployment_in_lock(vellum_client,
449
453
  ],
450
454
  "container_image_name": None,
451
455
  "container_image_tag": None,
456
+ "workspace": "default",
452
457
  },
453
458
  {
454
459
  "module": "workflow_87654321",
@@ -457,6 +462,7 @@ def test_pull__sandbox_id_with_other_workflow_deployment_in_lock(vellum_client,
457
462
  "deployments": [],
458
463
  "container_image_name": "test",
459
464
  "container_image_tag": "1.0",
465
+ "workspace": "default",
460
466
  },
461
467
  ]
462
468
 
@@ -3,6 +3,7 @@ import io
3
3
  import json
4
4
  import os
5
5
  import tarfile
6
+ from unittest import mock
6
7
  from uuid import uuid4
7
8
 
8
9
  from click.testing import CliRunner
@@ -194,7 +195,7 @@ def test_push__dry_run_option_returns_report(mock_module, vellum_client):
194
195
  from typing import Dict
195
196
  from vellum.workflows import BaseWorkflow
196
197
  from vellum.workflows.nodes import BaseNode
197
- from vellum_ee.workflows.display.nodes import BaseNodeDisplay
198
+ from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
198
199
 
199
200
  class NotSupportedNode(BaseNode):
200
201
  pass
@@ -318,3 +319,89 @@ Files that were different between the original project and the generated artifac
318
319
  \x1b[0m
319
320
  """
320
321
  )
322
+
323
+
324
+ def test_push__workspace_option__uses_different_api_key(mock_module, vellum_client_class):
325
+ # GIVEN a single workflow configured
326
+ temp_dir = mock_module.temp_dir
327
+ module = mock_module.module
328
+ workflow_sandbox_id = mock_module.workflow_sandbox_id
329
+ set_pyproject_toml = mock_module.set_pyproject_toml
330
+
331
+ # AND a different workspace is set in the pyproject.toml
332
+ set_pyproject_toml(
333
+ {
334
+ "workflows": [
335
+ {
336
+ "module": module,
337
+ "workflow_sandbox_id": workflow_sandbox_id,
338
+ }
339
+ ],
340
+ "workspaces": [
341
+ {
342
+ "name": "my_other_workspace",
343
+ "api_key": "MY_OTHER_VELLUM_API_KEY",
344
+ }
345
+ ],
346
+ }
347
+ )
348
+
349
+ # AND the .env file has the other api key stored
350
+ with open(os.path.join(temp_dir, ".env"), "w") as f:
351
+ f.write(
352
+ """\
353
+ VELLUM_API_KEY=abcdef123456
354
+ MY_OTHER_VELLUM_API_KEY=aaabbbcccddd
355
+ """
356
+ )
357
+
358
+ # AND a workflow exists in the module successfully
359
+ base_dir = os.path.join(temp_dir, *module.split("."))
360
+ os.makedirs(base_dir, exist_ok=True)
361
+ workflow_py_file_content = """\
362
+ from vellum.workflows import BaseWorkflow
363
+
364
+ class ExampleWorkflow(BaseWorkflow):
365
+ pass
366
+ """
367
+ with open(os.path.join(temp_dir, *module.split("."), "workflow.py"), "w") as f:
368
+ f.write(workflow_py_file_content)
369
+
370
+ # AND the push API call returns a new workflow sandbox id
371
+ new_workflow_sandbox_id = str(uuid4())
372
+ vellum_client_class.return_value.workflows.push.return_value = WorkflowPushResponse(
373
+ workflow_sandbox_id=new_workflow_sandbox_id,
374
+ )
375
+
376
+ # WHEN calling `vellum push` on strict mode
377
+ runner = CliRunner()
378
+ result = runner.invoke(cli_main, ["push", module, "--workspace", "my_other_workspace"])
379
+
380
+ # THEN it should succeed
381
+ assert result.exit_code == 0, result.output
382
+
383
+ # AND we should have called the push API once
384
+ vellum_client_class.return_value.workflows.push.assert_called_once()
385
+
386
+ # AND the workflow sandbox id arg passed in should be `None`
387
+ call_args = vellum_client_class.return_value.workflows.push.call_args.kwargs
388
+ assert call_args["workflow_sandbox_id"] is None
389
+
390
+ # AND with the correct api key
391
+ vellum_client_class.assert_called_once_with(
392
+ api_key="aaabbbcccddd",
393
+ environment=mock.ANY,
394
+ )
395
+
396
+ # AND the vellum lock file should have been updated with the correct workspace
397
+ with open(os.path.join(temp_dir, "vellum.lock.json")) as f:
398
+ lock_file_content = json.load(f)
399
+ assert lock_file_content["workflows"][1] == {
400
+ "module": module,
401
+ "workflow_sandbox_id": new_workflow_sandbox_id,
402
+ "workspace": "my_other_workspace",
403
+ "container_image_name": None,
404
+ "container_image_tag": None,
405
+ "deployments": [],
406
+ "ignore": None,
407
+ }
@@ -16,6 +16,8 @@ from typing import (
16
16
  get_origin,
17
17
  )
18
18
 
19
+ from vellum.workflows import BaseWorkflow
20
+ from vellum.workflows.constants import UNDEF
19
21
  from vellum.workflows.descriptors.base import BaseDescriptor
20
22
  from vellum.workflows.expressions.between import BetweenExpression
21
23
  from vellum.workflows.expressions.is_nil import IsNilExpression
@@ -26,19 +28,24 @@ from vellum.workflows.expressions.is_null import IsNullExpression
26
28
  from vellum.workflows.expressions.is_undefined import IsUndefinedExpression
27
29
  from vellum.workflows.expressions.not_between import NotBetweenExpression
28
30
  from vellum.workflows.nodes.bases.base import BaseNode
31
+ from vellum.workflows.nodes.utils import get_wrapped_node
29
32
  from vellum.workflows.ports import Port
30
33
  from vellum.workflows.references import OutputReference
34
+ from vellum.workflows.references.constant import ConstantValueReference
31
35
  from vellum.workflows.references.execution_count import ExecutionCountReference
36
+ from vellum.workflows.references.lazy import LazyReference
32
37
  from vellum.workflows.references.vellum_secret import VellumSecretReference
33
38
  from vellum.workflows.references.workflow_input import WorkflowInputReference
34
- from vellum.workflows.types.core import JsonObject
39
+ from vellum.workflows.types.core import JsonArray, JsonObject
35
40
  from vellum.workflows.types.generics import NodeType
36
41
  from vellum.workflows.types.utils import get_original_base
37
42
  from vellum.workflows.utils.names import pascal_to_title_case
38
43
  from vellum.workflows.utils.uuids import uuid4_from_hash
44
+ from vellum.workflows.utils.vellum_variables import primitive_type_to_vellum_variable_type
45
+ from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
39
46
  from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay, PortDisplay, PortDisplayOverrides
40
47
  from vellum_ee.workflows.display.utils.vellum import convert_descriptor_to_operator, primitive_to_vellum_value
41
- from vellum_ee.workflows.display.vellum import CodeResourceDefinition
48
+ from vellum_ee.workflows.display.vellum import CodeResourceDefinition, GenericNodeDisplayData
42
49
 
43
50
  if TYPE_CHECKING:
44
51
  from vellum_ee.workflows.display.types import WorkflowDisplayContext
@@ -65,7 +72,97 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
65
72
  _node_display_registry: Dict[Type[NodeType], Type["BaseNodeDisplay"]] = {}
66
73
 
67
74
  def serialize(self, display_context: "WorkflowDisplayContext", **kwargs: Any) -> JsonObject:
68
- raise NotImplementedError(f"Serialization for nodes of type {self._node.__name__} is not supported.")
75
+ node = self._node
76
+ node_id = self.node_id
77
+
78
+ attributes: JsonArray = []
79
+ for attribute in node:
80
+ if inspect.isclass(attribute.instance) and issubclass(attribute.instance, BaseWorkflow):
81
+ # We don't need to serialize generic node attributes containing a subworkflow
82
+ continue
83
+
84
+ id = str(uuid4_from_hash(f"{node_id}|{attribute.name}"))
85
+ attributes.append(
86
+ {
87
+ "id": id,
88
+ "name": attribute.name,
89
+ "value": self.serialize_value(display_context, cast(BaseDescriptor, attribute.instance)),
90
+ }
91
+ )
92
+
93
+ adornments = kwargs.get("adornments", None)
94
+ wrapped_node = get_wrapped_node(node)
95
+ if wrapped_node is not None:
96
+ display_class = get_node_display_class(BaseNodeDisplay, wrapped_node)
97
+
98
+ adornment: JsonObject = {
99
+ "id": str(node_id),
100
+ "label": node.__qualname__,
101
+ "base": self.get_base().dict(),
102
+ "attributes": attributes,
103
+ }
104
+
105
+ existing_adornments = adornments if adornments is not None else []
106
+ return display_class().serialize(display_context, adornments=existing_adornments + [adornment])
107
+
108
+ ports: JsonArray = []
109
+ for port in node.Ports:
110
+ id = str(self.get_node_port_display(port).id)
111
+
112
+ if port._condition_type:
113
+ ports.append(
114
+ {
115
+ "id": id,
116
+ "name": port.name,
117
+ "type": port._condition_type.value,
118
+ "expression": (
119
+ self.serialize_condition(display_context, port._condition) if port._condition else None
120
+ ),
121
+ }
122
+ )
123
+ else:
124
+ ports.append(
125
+ {
126
+ "id": id,
127
+ "name": port.name,
128
+ "type": "DEFAULT",
129
+ }
130
+ )
131
+
132
+ outputs: JsonArray = []
133
+ for output in node.Outputs:
134
+ type = primitive_type_to_vellum_variable_type(output)
135
+ value = (
136
+ self.serialize_value(display_context, output.instance)
137
+ if output.instance is not None and output.instance != UNDEF
138
+ else None
139
+ )
140
+
141
+ outputs.append(
142
+ {
143
+ "id": str(uuid4_from_hash(f"{node_id}|{output.name}")),
144
+ "name": output.name,
145
+ "type": type,
146
+ "value": value,
147
+ }
148
+ )
149
+
150
+ return {
151
+ "id": str(node_id),
152
+ "label": node.__qualname__,
153
+ "type": "GENERIC",
154
+ "display_data": self._get_generic_node_display_data().dict(),
155
+ "base": self.get_base().dict(),
156
+ "definition": self.get_definition().dict(),
157
+ "trigger": {
158
+ "id": str(self.get_trigger_id()),
159
+ "merge_behavior": node.Trigger.merge_behavior.value,
160
+ },
161
+ "ports": ports,
162
+ "adornments": adornments,
163
+ "attributes": attributes,
164
+ "outputs": outputs,
165
+ }
69
166
 
70
167
  def get_base(self) -> CodeResourceDefinition:
71
168
  node = self._node
@@ -89,9 +186,6 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
89
186
  )
90
187
  return node_definition
91
188
 
92
- def get_trigger_id(self) -> UUID:
93
- return uuid4_from_hash(f"{self.node_id}|trigger")
94
-
95
189
  def get_node_output_display(self, output: OutputReference) -> Tuple[Type[BaseNode], NodeOutputDisplay]:
96
190
  explicit_display = self.output_display.get(output)
97
191
  if explicit_display:
@@ -110,6 +204,9 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
110
204
 
111
205
  return PortDisplay(id=port_id, node_id=self.node_id)
112
206
 
207
+ def get_trigger_id(self) -> UUID:
208
+ return uuid4_from_hash(f"{self.node_id}|trigger")
209
+
113
210
  @classmethod
114
211
  def get_from_node_display_registry(cls, node_class: Type[NodeType]) -> Type["BaseNodeDisplay"]:
115
212
  return cls._node_display_registry[node_class]
@@ -187,10 +284,19 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
187
284
 
188
285
  def __init_subclass__(cls, **kwargs: Any) -> None:
189
286
  super().__init_subclass__(**kwargs)
287
+ if not cls._node_display_registry:
288
+ cls._node_display_registry[BaseNode] = BaseNodeDisplay
190
289
 
191
290
  node_class = cls.infer_node_class()
291
+ if node_class is BaseNode:
292
+ return
293
+
192
294
  cls._node_display_registry[node_class] = cls
193
295
 
296
+ def _get_generic_node_display_data(self) -> GenericNodeDisplayData:
297
+ explicit_value = self._get_explicit_node_display_attr("display_data", GenericNodeDisplayData)
298
+ return explicit_value if explicit_value else GenericNodeDisplayData()
299
+
194
300
  def serialize_condition(self, display_context: "WorkflowDisplayContext", condition: BaseDescriptor) -> JsonObject:
195
301
  if isinstance(
196
302
  condition,
@@ -233,6 +339,12 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
233
339
  }
234
340
 
235
341
  def serialize_value(self, display_context: "WorkflowDisplayContext", value: BaseDescriptor) -> JsonObject:
342
+ if isinstance(value, ConstantValueReference):
343
+ return self.serialize_value(display_context, value._value)
344
+
345
+ if isinstance(value, LazyReference):
346
+ return self.serialize_value(display_context, value._get())
347
+
236
348
  if isinstance(value, WorkflowInputReference):
237
349
  workflow_input_display = display_context.global_workflow_input_displays[value]
238
350
  return {
@@ -7,7 +7,7 @@ from vellum.workflows.types.generics import NodeType
7
7
  from vellum.workflows.utils.uuids import uuid4_from_hash
8
8
  from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
9
9
  from vellum_ee.workflows.display.nodes.types import PortDisplay
10
- from vellum_ee.workflows.display.vellum import NodeDisplayData
10
+ from vellum_ee.workflows.display.vellum import NodeDisplayComment, NodeDisplayData
11
11
 
12
12
 
13
13
  class BaseNodeVellumDisplay(BaseNodeDisplay[NodeType]):
@@ -26,6 +26,21 @@ class BaseNodeVellumDisplay(BaseNodeDisplay[NodeType]):
26
26
 
27
27
  def get_display_data(self) -> NodeDisplayData:
28
28
  explicit_value = self._get_explicit_node_display_attr("display_data", NodeDisplayData)
29
+ docstring = self._node.__doc__
30
+
31
+ if explicit_value and explicit_value.comment and docstring:
32
+ comment = (
33
+ NodeDisplayComment(value=docstring, expanded=explicit_value.comment.expanded)
34
+ if explicit_value.comment.expanded
35
+ else NodeDisplayComment(value=docstring)
36
+ )
37
+ return NodeDisplayData(
38
+ position=explicit_value.position,
39
+ width=explicit_value.width,
40
+ height=explicit_value.height,
41
+ comment=comment,
42
+ )
43
+
29
44
  return explicit_value if explicit_value else NodeDisplayData()
30
45
 
31
46
  def get_target_handle_id(self) -> UUID:
@@ -1,13 +1,15 @@
1
1
  import types
2
- from typing import Optional, Type
2
+ from typing import TYPE_CHECKING, Optional, Type
3
3
 
4
4
  from vellum.workflows.types.generics import NodeType
5
- from vellum_ee.workflows.display.types import NodeDisplayType
5
+
6
+ if TYPE_CHECKING:
7
+ from vellum_ee.workflows.display.types import NodeDisplayType
6
8
 
7
9
 
8
10
  def get_node_display_class(
9
- base_class: Type[NodeDisplayType], node_class: Type[NodeType], root_node_class: Optional[Type[NodeType]] = None
10
- ) -> Type[NodeDisplayType]:
11
+ base_class: Type["NodeDisplayType"], node_class: Type[NodeType], root_node_class: Optional[Type[NodeType]] = None
12
+ ) -> Type["NodeDisplayType"]:
11
13
  try:
12
14
  node_display_class = base_class.get_from_node_display_registry(node_class)
13
15
  except KeyError:
@@ -1,5 +1,4 @@
1
1
  from .api_node import BaseAPINodeDisplay
2
- from .base_node import BaseNodeDisplay
3
2
  from .code_execution_node import BaseCodeExecutionNodeDisplay
4
3
  from .conditional_node import BaseConditionalNodeDisplay
5
4
  from .error_node import BaseErrorNodeDisplay
@@ -28,7 +27,6 @@ __all__ = [
28
27
  "BaseInlineSubworkflowNodeDisplay",
29
28
  "BaseMapNodeDisplay",
30
29
  "BaseMergeNodeDisplay",
31
- "BaseNodeDisplay",
32
30
  "BaseNoteNodeDisplay",
33
31
  "BasePromptDeploymentNodeDisplay",
34
32
  "BaseSearchNodeDisplay",
@@ -1,9 +1,10 @@
1
1
  from uuid import UUID
2
- from typing import Any, ClassVar, Dict, Generic, Optional, TypeVar
2
+ from typing import ClassVar, Generic, Optional, TypeVar
3
3
 
4
4
  from vellum.workflows.nodes import ErrorNode
5
5
  from vellum.workflows.types.core import JsonObject
6
6
  from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeVellumDisplay
7
+ from vellum_ee.workflows.display.nodes.utils import raise_if_descriptor
7
8
  from vellum_ee.workflows.display.nodes.vellum.utils import create_node_input
8
9
  from vellum_ee.workflows.display.types import WorkflowDisplayContext
9
10
 
@@ -12,7 +13,7 @@ _ErrorNodeType = TypeVar("_ErrorNodeType", bound=ErrorNode)
12
13
 
13
14
  class BaseErrorNodeDisplay(BaseNodeVellumDisplay[_ErrorNodeType], Generic[_ErrorNodeType]):
14
15
  error_output_id: ClassVar[Optional[UUID]] = None
15
- error_inputs_by_name: ClassVar[Dict[str, Any]] = {}
16
+
16
17
  name: ClassVar[str] = "error-node"
17
18
 
18
19
  def serialize(
@@ -21,6 +22,11 @@ class BaseErrorNodeDisplay(BaseNodeVellumDisplay[_ErrorNodeType], Generic[_Error
21
22
  node_id = self.node_id
22
23
  error_source_input_id = self.node_input_ids_by_name.get("error_source_input_id")
23
24
 
25
+ error_attribute = raise_if_descriptor(self._node.error)
26
+ input_values_by_name = {
27
+ "error_source_input_id": error_attribute,
28
+ }
29
+
24
30
  node_inputs = [
25
31
  create_node_input(
26
32
  node_id=node_id,
@@ -29,7 +35,7 @@ class BaseErrorNodeDisplay(BaseNodeVellumDisplay[_ErrorNodeType], Generic[_Error
29
35
  display_context=display_context,
30
36
  input_id=self.node_input_ids_by_name.get(variable_name),
31
37
  )
32
- for variable_name, variable_value in self.error_inputs_by_name.items()
38
+ for variable_name, variable_value in input_values_by_name.items()
33
39
  ]
34
40
 
35
41
  return {
@@ -0,0 +1,44 @@
1
+ from typing import Any, Dict, cast
2
+
3
+ from vellum.client.types.vellum_error import VellumError
4
+ from vellum.workflows import BaseWorkflow
5
+ from vellum.workflows.nodes.core.error_node.node import ErrorNode
6
+ from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
7
+ from vellum_ee.workflows.display.workflows.vellum_workflow_display import VellumWorkflowDisplay
8
+
9
+
10
+ def test_error_node_display__serialize_with_vellum_error() -> None:
11
+ # GIVEN an Error Node with a VellumError
12
+ class MyNode(ErrorNode):
13
+ error = VellumError(
14
+ message="A bad thing happened",
15
+ code="USER_DEFINED_ERROR",
16
+ )
17
+
18
+ # AND a workflow referencing the two node
19
+ class MyWorkflow(BaseWorkflow):
20
+ graph = MyNode
21
+
22
+ # WHEN we serialize the workflow
23
+ workflow_display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=MyWorkflow)
24
+ serialized_workflow = cast(Dict[str, Any], workflow_display.serialize())
25
+
26
+ # THEN the correct inputs should be serialized on the node
27
+ serialized_node = next(
28
+ node for node in serialized_workflow["workflow_raw_data"]["nodes"] if node["id"] == str(MyNode.__id__)
29
+ )
30
+ assert serialized_node["inputs"][0]["value"] == {
31
+ "combinator": "OR",
32
+ "rules": [
33
+ {
34
+ "data": {
35
+ "type": "ERROR",
36
+ "value": {
37
+ "message": "A bad thing happened",
38
+ "code": "USER_DEFINED_ERROR",
39
+ },
40
+ },
41
+ "type": "CONSTANT_VALUE",
42
+ }
43
+ ],
44
+ }
@@ -13,7 +13,6 @@ from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeV
13
13
  from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
14
14
  from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay
15
15
  from vellum_ee.workflows.display.nodes.utils import raise_if_descriptor
16
- from vellum_ee.workflows.display.nodes.vellum.base_node import BaseNodeDisplay as GenericBaseNodeDisplay
17
16
  from vellum_ee.workflows.display.types import WorkflowDisplayContext
18
17
 
19
18
  _TryNodeType = TypeVar("_TryNodeType", bound=TryNode)
@@ -36,14 +35,14 @@ class BaseTryNodeDisplay(BaseNodeVellumDisplay[_TryNodeType], Generic[_TryNodeTy
36
35
  inner_node = subworkflow.graph
37
36
  elif inner_node.__bases__[0] is BaseNode:
38
37
  # If the wrapped node is a generic node, we let generic node do adornment handling
39
- class TryBaseNodeDisplay(GenericBaseNodeDisplay[node]): # type: ignore[valid-type]
38
+ class TryBaseNodeDisplay(BaseNodeDisplay[node]): # type: ignore[valid-type]
40
39
  pass
41
40
 
42
41
  return TryBaseNodeDisplay().serialize(display_context)
43
42
 
44
43
  # We need the node display class of the underlying node because
45
44
  # it contains the logic for serializing the node and potential display overrides
46
- node_display_class = get_node_display_class(BaseNodeVellumDisplay, inner_node)
45
+ node_display_class = get_node_display_class(BaseNodeDisplay, inner_node)
47
46
  node_display = node_display_class()
48
47
 
49
48
  serialized_node = node_display.serialize(
@@ -70,7 +69,7 @@ class BaseTryNodeDisplay(BaseNodeVellumDisplay[_TryNodeType], Generic[_TryNodeTy
70
69
  if not inner_node:
71
70
  return super().get_node_output_display(output)
72
71
 
73
- node_display_class = get_node_display_class(BaseNodeVellumDisplay, inner_node)
72
+ node_display_class = get_node_display_class(BaseNodeDisplay, inner_node)
74
73
  node_display = node_display_class()
75
74
  if output.name == "error":
76
75
  return inner_node, NodeOutputDisplay(
@@ -8,9 +8,9 @@ from vellum.workflows.references.workflow_input import WorkflowInputReference
8
8
  from vellum.workflows.types.core import JsonObject
9
9
  from vellum.workflows.types.generics import NodeType
10
10
  from vellum_ee.workflows.display.base import WorkflowInputsDisplayType
11
+ from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
11
12
  from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
12
13
  from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay
13
- from vellum_ee.workflows.display.nodes.vellum.base_node import BaseNodeDisplay
14
14
  from vellum_ee.workflows.display.types import NodeDisplayType, WorkflowDisplayContext
15
15
  from vellum_ee.workflows.display.vellum import NodeDisplayData, WorkflowMetaVellumDisplay
16
16
  from vellum_ee.workflows.display.workflows.vellum_workflow_display import VellumWorkflowDisplay
@@ -8,8 +8,8 @@ from vellum.workflows.nodes.core.retry_node.node import RetryNode
8
8
  from vellum.workflows.nodes.core.try_node.node import TryNode
9
9
  from vellum.workflows.outputs.base import BaseOutputs
10
10
  from vellum_ee.workflows.display.base import WorkflowInputsDisplay
11
+ from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
11
12
  from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeVellumDisplay
12
- from vellum_ee.workflows.display.nodes.vellum.base_node import BaseNodeDisplay
13
13
  from vellum_ee.workflows.display.nodes.vellum.try_node import BaseTryNodeDisplay
14
14
 
15
15
 
@@ -159,7 +159,10 @@ def test_serialize_node__try(serialize_node):
159
159
  {
160
160
  "id": "3344083c-a32c-4a32-920b-0fb5093448fa",
161
161
  "label": "TryNode",
162
- "base": {"name": "TryNode", "module": ["vellum", "workflows", "nodes", "core", "try_node", "node"]},
162
+ "base": {
163
+ "name": "TryNode",
164
+ "module": ["vellum", "workflows", "nodes", "core", "try_node", "node"],
165
+ },
163
166
  "attributes": [
164
167
  {
165
168
  "id": "ab2fbab0-e2a0-419b-b1ef-ce11ecf11e90",