vellum-ai 0.14.38__py3-none-any.whl → 0.14.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. vellum/__init__.py +2 -0
  2. vellum/client/core/client_wrapper.py +1 -1
  3. vellum/client/types/__init__.py +2 -0
  4. vellum/client/types/test_suite_run_progress.py +20 -0
  5. vellum/client/types/test_suite_run_read.py +3 -0
  6. vellum/client/types/vellum_sdk_error_code_enum.py +1 -0
  7. vellum/client/types/workflow_execution_event_error_code.py +1 -0
  8. vellum/types/test_suite_run_progress.py +3 -0
  9. vellum/workflows/errors/types.py +1 -0
  10. vellum/workflows/events/tests/test_event.py +1 -0
  11. vellum/workflows/events/workflow.py +13 -3
  12. vellum/workflows/nodes/bases/tests/test_base_node.py +24 -0
  13. vellum/workflows/nodes/core/try_node/node.py +1 -2
  14. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +7 -1
  15. vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +6 -1
  16. vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +26 -0
  17. vellum/workflows/nodes/experimental/tool_calling_node/__init__.py +3 -0
  18. vellum/workflows/nodes/experimental/tool_calling_node/node.py +147 -0
  19. vellum/workflows/nodes/experimental/tool_calling_node/utils.py +132 -0
  20. vellum/workflows/nodes/utils.py +4 -2
  21. vellum/workflows/outputs/base.py +3 -2
  22. vellum/workflows/references/output.py +20 -0
  23. vellum/workflows/runner/runner.py +37 -17
  24. vellum/workflows/state/base.py +64 -19
  25. vellum/workflows/state/tests/test_state.py +31 -22
  26. vellum/workflows/types/stack.py +11 -0
  27. vellum/workflows/workflows/base.py +13 -18
  28. vellum/workflows/workflows/tests/test_base_workflow.py +96 -9
  29. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/METADATA +1 -1
  30. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/RECORD +82 -75
  31. vellum_cli/push.py +2 -5
  32. vellum_cli/tests/test_push.py +52 -0
  33. vellum_ee/workflows/display/base.py +14 -1
  34. vellum_ee/workflows/display/nodes/base_node_display.py +56 -14
  35. vellum_ee/workflows/display/nodes/get_node_display_class.py +9 -15
  36. vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +36 -0
  37. vellum_ee/workflows/display/nodes/vellum/base_adornment_node.py +3 -2
  38. vellum_ee/workflows/display/nodes/vellum/retry_node.py +1 -2
  39. vellum_ee/workflows/display/nodes/vellum/tests/test_code_execution_node.py +1 -2
  40. vellum_ee/workflows/display/nodes/vellum/tests/test_error_node.py +1 -2
  41. vellum_ee/workflows/display/nodes/vellum/tests/test_note_node.py +1 -2
  42. vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_node.py +55 -3
  43. vellum_ee/workflows/display/nodes/vellum/tests/test_retry_node.py +1 -2
  44. vellum_ee/workflows/display/nodes/vellum/tests/test_templating_node.py +1 -2
  45. vellum_ee/workflows/display/nodes/vellum/tests/test_try_node.py +1 -2
  46. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +2 -2
  47. vellum_ee/workflows/display/nodes/vellum/try_node.py +1 -2
  48. vellum_ee/workflows/display/nodes/vellum/utils.py +7 -1
  49. vellum_ee/workflows/display/tests/{test_vellum_workflow_display.py → test_base_workflow_display.py} +10 -22
  50. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +4 -6
  51. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +7 -16
  52. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +2 -6
  53. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +1 -2
  54. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +3 -10
  55. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +4 -5
  56. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +1 -4
  57. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +1 -4
  58. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +2 -5
  59. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +7 -5
  60. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_prompt_node_serialization.py +1 -4
  61. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -4
  62. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +1 -2
  63. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +1 -4
  64. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -4
  65. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +7 -5
  66. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +1 -4
  67. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -4
  68. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +1 -4
  69. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +2 -5
  70. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +2 -7
  71. vellum_ee/workflows/display/types.py +5 -4
  72. vellum_ee/workflows/display/utils/exceptions.py +7 -0
  73. vellum_ee/workflows/display/utils/registry.py +37 -0
  74. vellum_ee/workflows/display/utils/vellum.py +2 -1
  75. vellum_ee/workflows/display/workflows/base_workflow_display.py +281 -43
  76. vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +34 -21
  77. vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +58 -20
  78. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +4 -257
  79. vellum_ee/workflows/tests/local_workflow/display/workflow.py +2 -2
  80. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/LICENSE +0 -0
  81. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/WHEEL +0 -0
  82. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/entry_points.txt +0 -0
@@ -133,6 +133,58 @@ def test_push__happy_path(mock_module, vellum_client, base_command):
133
133
  assert extracted_files["workflow.py"] == workflow_py_file_content
134
134
 
135
135
 
136
+ def test_push__no_config__module_found(mock_module, vellum_client):
137
+ # GIVEN no config file set
138
+ temp_dir = mock_module.temp_dir
139
+ module = mock_module.module
140
+ mock_module.set_pyproject_toml({"workflows": []})
141
+
142
+ # AND a workflow exists in the module successfully
143
+ workflow_py_file_content = _ensure_workflow_py(temp_dir, module)
144
+
145
+ # AND the push API call returns successfully
146
+ new_workflow_sandbox_id = str(uuid4())
147
+ vellum_client.workflows.push.return_value = WorkflowPushResponse(
148
+ workflow_sandbox_id=new_workflow_sandbox_id,
149
+ )
150
+
151
+ # WHEN calling `vellum push`
152
+ runner = CliRunner()
153
+ result = runner.invoke(cli_main, ["push", module])
154
+
155
+ # THEN it should successfully push the workflow
156
+ assert result.exit_code == 0, result.stdout
157
+
158
+ # Get the last part of the module path and format it
159
+ expected_artifact_name = f"{mock_module.module.replace('.', '__')}.tar.gz"
160
+
161
+ # AND we should have called the push API with the correct args
162
+ vellum_client.workflows.push.assert_called_once()
163
+ call_args = vellum_client.workflows.push.call_args.kwargs
164
+ assert json.loads(call_args["exec_config"])["workflow_raw_data"]["definition"]["name"] == "ExampleWorkflow"
165
+ assert call_args["workflow_sandbox_id"] is None
166
+ assert "deplyment_config" not in call_args
167
+
168
+ # AND we should have pushed the correct artifact
169
+ assert call_args["artifact"].name == expected_artifact_name
170
+ extracted_files = _extract_tar_gz(call_args["artifact"].read())
171
+ assert extracted_files["workflow.py"] == workflow_py_file_content
172
+
173
+ # AND there should be a new entry in the lock file
174
+ with open(os.path.join(temp_dir, "vellum.lock.json")) as f:
175
+ lock_file_content = json.load(f)
176
+ assert lock_file_content["workflows"][0] == {
177
+ "container_image_name": None,
178
+ "container_image_tag": None,
179
+ "deployments": [],
180
+ "ignore": None,
181
+ "module": module,
182
+ "target_directory": None,
183
+ "workflow_sandbox_id": new_workflow_sandbox_id,
184
+ "workspace": "default",
185
+ }
186
+
187
+
136
188
  @pytest.mark.parametrize(
137
189
  "base_command",
138
190
  [
@@ -1,10 +1,12 @@
1
1
  from dataclasses import dataclass, field
2
2
  from uuid import UUID
3
- from typing import Optional
3
+ from typing import Optional, Type
4
4
 
5
5
  from pydantic import Field
6
6
 
7
7
  from vellum.client.core.pydantic_utilities import UniversalBaseModel
8
+ from vellum.workflows.utils.uuids import uuid4_from_hash
9
+ from vellum.workflows.workflows.base import BaseWorkflow
8
10
  from vellum_ee.workflows.display.editor.types import NodeDisplayData
9
11
 
10
12
 
@@ -25,6 +27,17 @@ class WorkflowMetaDisplay:
25
27
  entrypoint_node_display: NodeDisplayData = Field(default_factory=NodeDisplayData)
26
28
  display_data: WorkflowDisplayData = field(default_factory=WorkflowDisplayData)
27
29
 
30
+ @classmethod
31
+ def get_default(cls, workflow: Type[BaseWorkflow]) -> "WorkflowMetaDisplay":
32
+ entrypoint_node_id = uuid4_from_hash(f"{workflow.__id__}|entrypoint_node_id")
33
+ entrypoint_node_source_handle_id = uuid4_from_hash(f"{workflow.__id__}|entrypoint_node_source_handle_id")
34
+
35
+ return WorkflowMetaDisplay(
36
+ entrypoint_node_id=entrypoint_node_id,
37
+ entrypoint_node_source_handle_id=entrypoint_node_source_handle_id,
38
+ entrypoint_node_display=NodeDisplayData(),
39
+ )
40
+
28
41
 
29
42
  @dataclass
30
43
  class WorkflowMetaDisplayOverrides(WorkflowMetaDisplay):
@@ -21,14 +21,31 @@ from vellum.client.types.code_resource_definition import CodeResourceDefinition
21
21
  from vellum.workflows import BaseWorkflow
22
22
  from vellum.workflows.constants import undefined
23
23
  from vellum.workflows.descriptors.base import BaseDescriptor
24
+ from vellum.workflows.expressions.accessor import AccessorExpression
25
+ from vellum.workflows.expressions.and_ import AndExpression
26
+ from vellum.workflows.expressions.begins_with import BeginsWithExpression
24
27
  from vellum.workflows.expressions.between import BetweenExpression
28
+ from vellum.workflows.expressions.coalesce_expression import CoalesceExpression
29
+ from vellum.workflows.expressions.contains import ContainsExpression
30
+ from vellum.workflows.expressions.does_not_begin_with import DoesNotBeginWithExpression
31
+ from vellum.workflows.expressions.does_not_contain import DoesNotContainExpression
32
+ from vellum.workflows.expressions.does_not_end_with import DoesNotEndWithExpression
33
+ from vellum.workflows.expressions.ends_with import EndsWithExpression
34
+ from vellum.workflows.expressions.equals import EqualsExpression
35
+ from vellum.workflows.expressions.greater_than import GreaterThanExpression
36
+ from vellum.workflows.expressions.greater_than_or_equal_to import GreaterThanOrEqualToExpression
37
+ from vellum.workflows.expressions.in_ import InExpression
25
38
  from vellum.workflows.expressions.is_nil import IsNilExpression
26
39
  from vellum.workflows.expressions.is_not_nil import IsNotNilExpression
27
40
  from vellum.workflows.expressions.is_not_null import IsNotNullExpression
28
41
  from vellum.workflows.expressions.is_not_undefined import IsNotUndefinedExpression
29
42
  from vellum.workflows.expressions.is_null import IsNullExpression
30
43
  from vellum.workflows.expressions.is_undefined import IsUndefinedExpression
44
+ from vellum.workflows.expressions.less_than import LessThanExpression
45
+ from vellum.workflows.expressions.less_than_or_equal_to import LessThanOrEqualToExpression
31
46
  from vellum.workflows.expressions.not_between import NotBetweenExpression
47
+ from vellum.workflows.expressions.not_in import NotInExpression
48
+ from vellum.workflows.expressions.or_ import OrExpression
32
49
  from vellum.workflows.expressions.parse_json import ParseJsonExpression
33
50
  from vellum.workflows.nodes.bases.base import BaseNode
34
51
  from vellum.workflows.nodes.displayable.bases.utils import primitive_to_vellum_value
@@ -50,7 +67,9 @@ from vellum.workflows.utils.vellum_variables import primitive_type_to_vellum_var
50
67
  from vellum_ee.workflows.display.editor.types import NodeDisplayData
51
68
  from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
52
69
  from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay, PortDisplay, PortDisplayOverrides
70
+ from vellum_ee.workflows.display.utils.exceptions import UnsupportedSerializationException
53
71
  from vellum_ee.workflows.display.utils.expressions import get_child_descriptor
72
+ from vellum_ee.workflows.display.utils.registry import register_node_display_class
54
73
  from vellum_ee.workflows.display.utils.vellum import convert_descriptor_to_operator
55
74
 
56
75
  if TYPE_CHECKING:
@@ -121,9 +140,6 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
121
140
  # Once all nodes are Generic Nodes, we may replace this with a trigger_id or trigger attribute
122
141
  target_handle_id: ClassVar[Optional[UUID]] = None
123
142
 
124
- # Used to store the mapping between node types and their display classes
125
- _node_display_registry: Dict[Type[NodeType], Type["BaseNodeDisplay"]] = {}
126
-
127
143
  def serialize(self, display_context: "WorkflowDisplayContext", **kwargs: Any) -> JsonObject:
128
144
  node = self._node
129
145
  node_id = self.node_id
@@ -149,7 +165,7 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
149
165
  adornments = kwargs.get("adornments", None)
150
166
  wrapped_node = get_wrapped_node(node)
151
167
  if wrapped_node is not None:
152
- display_class = get_node_display_class(BaseNodeDisplay, wrapped_node)
168
+ display_class = get_node_display_class(wrapped_node)
153
169
 
154
170
  adornment: JsonObject = {
155
171
  "id": str(node_id),
@@ -280,10 +296,6 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
280
296
 
281
297
  return self.get_target_handle_id()
282
298
 
283
- @classmethod
284
- def get_from_node_display_registry(cls, node_class: Type[NodeType]) -> Optional[Type["BaseNodeDisplay"]]:
285
- return cls._node_display_registry.get(node_class)
286
-
287
299
  @cached_property
288
300
  def node_id(self) -> UUID:
289
301
  """Can be overridden as a class attribute to specify a custom node id."""
@@ -344,14 +356,12 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
344
356
 
345
357
  def __init_subclass__(cls, **kwargs: Any) -> None:
346
358
  super().__init_subclass__(**kwargs)
347
- if not cls._node_display_registry:
348
- cls._node_display_registry[BaseNode] = BaseNodeDisplay
349
359
 
350
360
  node_class = cls.infer_node_class()
351
361
  if node_class is BaseNode:
352
362
  return
353
363
 
354
- cls._node_display_registry[node_class] = cls
364
+ register_node_display_class(node_class=node_class, node_display_class=cls)
355
365
 
356
366
  def _get_generic_node_display_data(self) -> NodeDisplayData:
357
367
  explicit_value = self._get_explicit_node_display_attr("display_data", NodeDisplayData)
@@ -388,9 +398,29 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
388
398
  "lhs": lhs,
389
399
  "rhs": rhs,
390
400
  }
391
- else:
392
- lhs = self.serialize_value(display_context, condition._lhs) # type: ignore[attr-defined]
393
- rhs = self.serialize_value(display_context, condition._rhs) # type: ignore[attr-defined]
401
+ elif isinstance(
402
+ condition,
403
+ (
404
+ AndExpression,
405
+ BeginsWithExpression,
406
+ CoalesceExpression,
407
+ ContainsExpression,
408
+ DoesNotBeginWithExpression,
409
+ DoesNotContainExpression,
410
+ DoesNotEndWithExpression,
411
+ EndsWithExpression,
412
+ EqualsExpression,
413
+ GreaterThanExpression,
414
+ GreaterThanOrEqualToExpression,
415
+ InExpression,
416
+ LessThanExpression,
417
+ LessThanOrEqualToExpression,
418
+ NotInExpression,
419
+ OrExpression,
420
+ ),
421
+ ):
422
+ lhs = self.serialize_value(display_context, condition._lhs)
423
+ rhs = self.serialize_value(display_context, condition._rhs)
394
424
 
395
425
  return {
396
426
  "type": "BINARY_EXPRESSION",
@@ -398,6 +428,15 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
398
428
  "operator": convert_descriptor_to_operator(condition),
399
429
  "rhs": rhs,
400
430
  }
431
+ elif isinstance(condition, AccessorExpression):
432
+ return {
433
+ "type": "BINARY_EXPRESSION",
434
+ "lhs": self.serialize_value(display_context, condition._base),
435
+ "operator": "accessField",
436
+ "rhs": self.serialize_value(display_context, condition._field),
437
+ }
438
+
439
+ raise UnsupportedSerializationException(f"Unsupported condition type: {condition.__class__.__name__}")
401
440
 
402
441
  def serialize_value(self, display_context: "WorkflowDisplayContext", value: Any) -> JsonObject:
403
442
  if isinstance(value, ConstantValueReference):
@@ -458,3 +497,6 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
458
497
  # If it's not any of the references we know about,
459
498
  # then try to serialize it as a nested value
460
499
  return self.serialize_condition(display_context, value)
500
+
501
+
502
+ register_node_display_class(node_class=BaseNode, node_display_class=BaseNodeDisplay)
@@ -1,33 +1,27 @@
1
1
  import types
2
2
  from uuid import UUID
3
- from typing import TYPE_CHECKING, Any, Dict, Optional, Type
3
+ from typing import TYPE_CHECKING, Any, Dict, Generic, Type, TypeVar
4
4
 
5
5
  from vellum.workflows.descriptors.base import BaseDescriptor
6
6
  from vellum.workflows.types.generics import NodeType
7
7
  from vellum.workflows.utils.uuids import uuid4_from_hash
8
+ from vellum_ee.workflows.display.utils.registry import get_from_node_display_registry
8
9
 
9
10
  if TYPE_CHECKING:
10
11
  from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
11
12
 
12
13
 
13
- def get_node_display_class(
14
- base_class: Type["BaseNodeDisplay"], node_class: Type[NodeType], root_node_class: Optional[Type[NodeType]] = None
15
- ) -> Type["BaseNodeDisplay"]:
16
- node_display_class = base_class.get_from_node_display_registry(node_class)
14
+ def get_node_display_class(node_class: Type[NodeType]) -> Type["BaseNodeDisplay"]:
15
+ node_display_class = get_from_node_display_registry(node_class)
17
16
  if node_display_class:
18
- if not issubclass(node_display_class, base_class):
19
- raise TypeError(
20
- f"Expected to find a subclass of '{base_class.__name__}' for node class '{node_class.__name__}'"
21
- )
22
-
23
17
  return node_display_class
24
18
 
25
- base_node_display_class = get_node_display_class(
26
- base_class, node_class.__bases__[0], node_class if root_node_class is None else root_node_class
27
- )
19
+ base_node_display_class = get_node_display_class(node_class.__bases__[0])
28
20
 
21
+ # mypy gets upset at dynamic TypeVar's, but it's technically allowed by python
22
+ _NodeClassType = TypeVar(f"_{node_class.__name__}Type", bound=node_class) # type: ignore[misc]
29
23
  # `base_node_display_class` is always a Generic class, so it's safe to index into it
30
- NodeDisplayBaseClass = base_node_display_class[node_class] # type: ignore[index]
24
+ NodeDisplayBaseClass = base_node_display_class[_NodeClassType] # type: ignore[index]
31
25
 
32
26
  def _get_node_input_ids_by_ref(path: str, inst: Any):
33
27
  if isinstance(inst, dict):
@@ -51,7 +45,7 @@ def get_node_display_class(
51
45
 
52
46
  NodeDisplayClass = types.new_class(
53
47
  f"{node_class.__name__}Display",
54
- bases=(NodeDisplayBaseClass,),
48
+ bases=(NodeDisplayBaseClass, Generic[_NodeClassType]),
55
49
  exec_body=exec_body,
56
50
  )
57
51
 
@@ -2,7 +2,11 @@ import pytest
2
2
  from uuid import UUID
3
3
 
4
4
  from vellum.workflows.nodes.bases import BaseNode
5
+ from vellum.workflows.ports.port import Port
6
+ from vellum.workflows.references.constant import ConstantValueReference
5
7
  from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
8
+ from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
9
+ from vellum_ee.workflows.display.types import WorkflowDisplayContext
6
10
 
7
11
 
8
12
  @pytest.fixture
@@ -46,3 +50,35 @@ def test_get_id(node_info):
46
50
 
47
51
  assert node_display().node_id == expected_id
48
52
  assert node_display.infer_node_class().__id__ == expected_id
53
+
54
+
55
+ def test_serialize_condition__accessor_expression():
56
+ # GIVEN a node with an accessor expression in a Port
57
+ class MyNode(BaseNode):
58
+ class Ports(BaseNode.Ports):
59
+ foo = Port.on_if(ConstantValueReference({"hello": "world"})["hello"])
60
+
61
+ # WHEN we serialize the node
62
+ node_display_class = get_node_display_class(MyNode)
63
+ data = node_display_class().serialize(WorkflowDisplayContext())
64
+
65
+ # THEN the condition should be serialized correctly
66
+ assert data["ports"] == [
67
+ {
68
+ "id": "7de6ea94-7f6c-475e-8f38-ec8ac317fd19",
69
+ "name": "foo",
70
+ "type": "IF",
71
+ "expression": {
72
+ "type": "BINARY_EXPRESSION",
73
+ "lhs": {
74
+ "type": "CONSTANT_VALUE",
75
+ "value": {
76
+ "type": "JSON",
77
+ "value": {"hello": "world"},
78
+ },
79
+ },
80
+ "operator": "accessField",
81
+ "rhs": {"type": "CONSTANT_VALUE", "value": {"type": "STRING", "value": "hello"}},
82
+ },
83
+ }
84
+ ]
@@ -14,6 +14,7 @@ from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeV
14
14
  from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
15
15
  from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay
16
16
  from vellum_ee.workflows.display.types import WorkflowDisplayContext
17
+ from vellum_ee.workflows.display.utils.registry import register_node_display_class
17
18
 
18
19
  _BaseAdornmentNodeType = TypeVar("_BaseAdornmentNodeType", bound=BaseAdornmentNode)
19
20
 
@@ -31,7 +32,7 @@ def _recursively_replace_wrapped_node(node_class: Type[BaseNode], wrapped_node_d
31
32
  # 1. The node display class' parameterized type
32
33
  original_base_node_display = get_original_base(wrapped_node_display_class)
33
34
  original_base_node_display.__args__ = (wrapped_node_class,)
34
- wrapped_node_display_class._node_display_registry[wrapped_node_class] = wrapped_node_display_class
35
+ register_node_display_class(node_class=wrapped_node_class, node_display_class=wrapped_node_display_class)
35
36
  wrapped_node_display_class.__annotate_node__()
36
37
 
37
38
  # 2. The node display class' output displays
@@ -89,7 +90,7 @@ class BaseAdornmentNodeDisplay(BaseNodeVellumDisplay[_BaseAdornmentNodeType], Ge
89
90
  "Unable to serialize standalone adornment nodes. Please use adornment nodes as a decorator."
90
91
  )
91
92
 
92
- wrapped_node_display_class = get_node_display_class(BaseNodeDisplay, wrapped_node)
93
+ wrapped_node_display_class = get_node_display_class(wrapped_node)
93
94
  wrapped_node_display = wrapped_node_display_class()
94
95
  additional_kwargs = get_additional_kwargs(wrapped_node_display.node_id) if get_additional_kwargs else {}
95
96
  serialized_wrapped_node = wrapped_node_display.serialize(display_context, **kwargs, **additional_kwargs)
@@ -9,7 +9,6 @@ from vellum.workflows.references.output import OutputReference
9
9
  from vellum.workflows.types.core import JsonArray, JsonObject
10
10
  from vellum.workflows.utils.uuids import uuid4_from_hash
11
11
  from vellum.workflows.workflows.base import BaseWorkflow
12
- from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
13
12
  from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
14
13
  from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay
15
14
  from vellum_ee.workflows.display.nodes.vellum.base_adornment_node import BaseAdornmentNodeDisplay
@@ -72,7 +71,7 @@ class BaseRetryNodeDisplay(BaseAdornmentNodeDisplay[_RetryNodeType], Generic[_Re
72
71
  if not inner_node:
73
72
  return super().get_node_output_display(output)
74
73
 
75
- node_display_class = get_node_display_class(BaseNodeDisplay, inner_node)
74
+ node_display_class = get_node_display_class(inner_node)
76
75
  node_display = node_display_class()
77
76
 
78
77
  inner_output = getattr(inner_node.Outputs, output.name)
@@ -6,7 +6,6 @@ from vellum.workflows.nodes.displayable.code_execution_node.node import CodeExec
6
6
  from vellum.workflows.workflows.base import BaseWorkflow
7
7
  from vellum_ee.workflows.display.nodes.vellum.code_execution_node import BaseCodeExecutionNodeDisplay
8
8
  from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
9
- from vellum_ee.workflows.display.workflows.vellum_workflow_display import VellumWorkflowDisplay
10
9
 
11
10
 
12
11
  def _no_display_class(Node: Type[CodeExecutionNode]):
@@ -53,7 +52,7 @@ def test_serialize_node__code_node_inputs(GetDisplayClass, expected_input_id):
53
52
  GetDisplayClass(MyCodeExecutionNode)
54
53
 
55
54
  # WHEN the workflow is serialized
56
- workflow_display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=Workflow)
55
+ workflow_display = get_workflow_display(workflow_class=Workflow)
57
56
  serialized_workflow: dict = workflow_display.serialize()
58
57
 
59
58
  # THEN the node should properly serialize the inputs
@@ -4,7 +4,6 @@ from vellum.client.types.vellum_error import VellumError
4
4
  from vellum.workflows import BaseWorkflow
5
5
  from vellum.workflows.nodes.core.error_node.node import ErrorNode
6
6
  from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
7
- from vellum_ee.workflows.display.workflows.vellum_workflow_display import VellumWorkflowDisplay
8
7
 
9
8
 
10
9
  def test_error_node_display__serialize_with_vellum_error() -> None:
@@ -20,7 +19,7 @@ def test_error_node_display__serialize_with_vellum_error() -> None:
20
19
  graph = MyNode
21
20
 
22
21
  # WHEN we serialize the workflow
23
- workflow_display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=MyWorkflow)
22
+ workflow_display = get_workflow_display(workflow_class=MyWorkflow)
24
23
  serialized_workflow = cast(Dict[str, Any], workflow_display.serialize())
25
24
 
26
25
  # THEN the correct inputs should be serialized on the node
@@ -2,7 +2,6 @@ from vellum.workflows import BaseWorkflow
2
2
  from vellum.workflows.nodes.displayable.note_node.node import NoteNode
3
3
  from vellum_ee.workflows.display.nodes.vellum.note_node import BaseNoteNodeDisplay
4
4
  from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
5
- from vellum_ee.workflows.display.workflows.vellum_workflow_display import VellumWorkflowDisplay
6
5
 
7
6
 
8
7
  def test_serialize_node__note_node():
@@ -22,7 +21,7 @@ def test_serialize_node__note_node():
22
21
  graph = MyNoteNode
23
22
 
24
23
  # WHEN the workflow is serialized
25
- workflow_display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=Workflow)
24
+ workflow_display = get_workflow_display(workflow_class=Workflow)
26
25
  serialized_workflow: dict = workflow_display.serialize()
27
26
 
28
27
  # THEN the node should properly serialize the inputs
@@ -6,9 +6,9 @@ from vellum.workflows import BaseWorkflow
6
6
  from vellum.workflows.nodes import BaseNode
7
7
  from vellum.workflows.nodes.displayable.inline_prompt_node.node import InlinePromptNode
8
8
  from vellum.workflows.references.lazy import LazyReference
9
+ from vellum.workflows.state.base import BaseState
9
10
  from vellum_ee.workflows.display.nodes.vellum.inline_prompt_node import BaseInlinePromptNodeDisplay
10
11
  from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
11
- from vellum_ee.workflows.display.workflows.vellum_workflow_display import VellumWorkflowDisplay
12
12
 
13
13
 
14
14
  def test_serialize_node__lazy_reference_in_prompt_inputs():
@@ -27,7 +27,7 @@ def test_serialize_node__lazy_reference_in_prompt_inputs():
27
27
  graph = LazyReferencePromptNode >> OtherNode
28
28
 
29
29
  # WHEN the workflow is serialized
30
- workflow_display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=Workflow)
30
+ workflow_display = get_workflow_display(workflow_class=Workflow)
31
31
  serialized_workflow: dict = workflow_display.serialize()
32
32
 
33
33
  # THEN the node should properly serialize the attribute reference
@@ -103,7 +103,7 @@ def test_serialize_node__prompt_inputs(GetDisplayClass, expected_input_id):
103
103
  GetDisplayClass(MyPromptNode)
104
104
 
105
105
  # WHEN the workflow is serialized
106
- workflow_display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=Workflow)
106
+ workflow_display = get_workflow_display(workflow_class=Workflow)
107
107
  serialized_workflow: dict = workflow_display.serialize()
108
108
 
109
109
  # THEN the node should properly serialize the inputs
@@ -129,3 +129,55 @@ def test_serialize_node__prompt_inputs(GetDisplayClass, expected_input_id):
129
129
  },
130
130
  }
131
131
  ]
132
+
133
+
134
+ def test_serialize_node__prompt_inputs__state_reference():
135
+ # GIVEN a state definition
136
+ class MyState(BaseState):
137
+ foo: str
138
+
139
+ # AND a prompt node with inputs
140
+ class MyPromptNode(InlinePromptNode):
141
+ prompt_inputs = {"foo": MyState.foo, "bar": "baz"}
142
+ blocks = []
143
+ ml_model = "gpt-4o"
144
+
145
+ # AND a workflow with the prompt node
146
+ class Workflow(BaseWorkflow):
147
+ graph = MyPromptNode
148
+
149
+ # WHEN the workflow is serialized
150
+ workflow_display = get_workflow_display(workflow_class=Workflow)
151
+ serialized_workflow: dict = workflow_display.serialize()
152
+
153
+ # THEN the node should skip the state reference input rule
154
+ my_prompt_node = next(
155
+ node for node in serialized_workflow["workflow_raw_data"]["nodes"] if node["id"] == str(MyPromptNode.__id__)
156
+ )
157
+
158
+ assert my_prompt_node["inputs"] == [
159
+ {
160
+ "id": "e47e0a80-afbb-4888-b06b-8dc78edd8572",
161
+ "key": "foo",
162
+ "value": {
163
+ "rules": [],
164
+ "combinator": "OR",
165
+ },
166
+ },
167
+ {
168
+ "id": "3750feb9-5d5c-4150-b62d-a9924f466888",
169
+ "key": "bar",
170
+ "value": {
171
+ "rules": [
172
+ {
173
+ "type": "CONSTANT_VALUE",
174
+ "data": {
175
+ "type": "STRING",
176
+ "value": "baz",
177
+ },
178
+ }
179
+ ],
180
+ "combinator": "OR",
181
+ },
182
+ },
183
+ ]
@@ -5,7 +5,6 @@ from vellum.workflows.errors.types import WorkflowErrorCode
5
5
  from vellum.workflows.nodes.bases.base import BaseNode
6
6
  from vellum.workflows.nodes.core.retry_node.node import RetryNode
7
7
  from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
8
- from vellum_ee.workflows.display.workflows.vellum_workflow_display import VellumWorkflowDisplay
9
8
 
10
9
 
11
10
  def test_retry_node_parameters():
@@ -21,7 +20,7 @@ def test_retry_node_parameters():
21
20
  graph = MyRetryNode
22
21
 
23
22
  # WHEN we serialize the workflow
24
- workflow_display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=MyWorkflow)
23
+ workflow_display = get_workflow_display(workflow_class=MyWorkflow)
25
24
  serialized_workflow = cast(Dict[str, Any], workflow_display.serialize())
26
25
 
27
26
  # THEN the correct inputs should be serialized on the node
@@ -6,7 +6,6 @@ from vellum.workflows import BaseWorkflow
6
6
  from vellum.workflows.nodes.core.templating_node.node import TemplatingNode
7
7
  from vellum_ee.workflows.display.nodes.vellum.templating_node import BaseTemplatingNodeDisplay
8
8
  from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
9
- from vellum_ee.workflows.display.workflows.vellum_workflow_display import VellumWorkflowDisplay
10
9
 
11
10
 
12
11
  def _no_display_class(Node: Type[TemplatingNode]):
@@ -53,7 +52,7 @@ def test_serialize_node__templating_node_inputs(GetDisplayClass, expected_input_
53
52
  GetDisplayClass(MyTemplatingNode)
54
53
 
55
54
  # WHEN the workflow is serialized
56
- workflow_display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=Workflow)
55
+ workflow_display = get_workflow_display(workflow_class=Workflow)
57
56
  serialized_workflow: dict = workflow_display.serialize()
58
57
 
59
58
  # THEN the node should properly serialize the inputs
@@ -5,7 +5,6 @@ from vellum.workflows.nodes.bases.base import BaseNode
5
5
  from vellum.workflows.nodes.core.templating_node.node import TemplatingNode
6
6
  from vellum.workflows.nodes.core.try_node.node import TryNode
7
7
  from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
8
- from vellum_ee.workflows.display.workflows.vellum_workflow_display import VellumWorkflowDisplay
9
8
 
10
9
 
11
10
  def test_try_node_display__serialize_with_error_output() -> None:
@@ -28,7 +27,7 @@ def test_try_node_display__serialize_with_error_output() -> None:
28
27
  graph = MyNode >> OtherNode
29
28
 
30
29
  # WHEN we serialize the workflow
31
- workflow_display = get_workflow_display(base_display_class=VellumWorkflowDisplay, workflow_class=MyWorkflow)
30
+ workflow_display = get_workflow_display(workflow_class=MyWorkflow)
32
31
  serialized_workflow = cast(Dict[str, Any], workflow_display.serialize())
33
32
 
34
33
  # THEN the correct inputs should be serialized on the node
@@ -22,7 +22,7 @@ from vellum_ee.workflows.display.utils.vellum import (
22
22
  NodeOutputPointer,
23
23
  )
24
24
  from vellum_ee.workflows.display.vellum import WorkflowInputsVellumDisplayOverrides, WorkflowMetaVellumDisplay
25
- from vellum_ee.workflows.display.workflows.vellum_workflow_display import VellumWorkflowDisplay
25
+ from vellum_ee.workflows.display.workflows.base_workflow_display import BaseWorkflowDisplay
26
26
 
27
27
 
28
28
  class Inputs(BaseInputs):
@@ -109,7 +109,7 @@ def test_create_node_input_value_pointer_rules(
109
109
  rules = create_node_input_value_pointer_rules(
110
110
  descriptor,
111
111
  WorkflowDisplayContext(
112
- workflow_display_class=VellumWorkflowDisplay,
112
+ workflow_display_class=BaseWorkflowDisplay,
113
113
  workflow_display=WorkflowMetaVellumDisplay(
114
114
  entrypoint_node_id=uuid4(),
115
115
  entrypoint_node_source_handle_id=uuid4(),
@@ -10,7 +10,6 @@ from vellum.workflows.references.output import OutputReference
10
10
  from vellum.workflows.types.core import JsonArray, JsonObject
11
11
  from vellum.workflows.utils.uuids import uuid4_from_hash
12
12
  from vellum.workflows.workflows.base import BaseWorkflow
13
- from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
14
13
  from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
15
14
  from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay
16
15
  from vellum_ee.workflows.display.nodes.vellum.base_adornment_node import BaseAdornmentNodeDisplay
@@ -82,7 +81,7 @@ class BaseTryNodeDisplay(BaseAdornmentNodeDisplay[_TryNodeType], Generic[_TryNod
82
81
  if not inner_node:
83
82
  return super().get_node_output_display(output)
84
83
 
85
- node_display_class = get_node_display_class(BaseNodeDisplay, inner_node)
84
+ node_display_class = get_node_display_class(inner_node)
86
85
  node_display = node_display_class()
87
86
  if output.name == "error":
88
87
  return inner_node, NodeOutputDisplay(
@@ -8,6 +8,7 @@ from vellum.workflows.references import NodeReference
8
8
  from vellum.workflows.references.lazy import LazyReference
9
9
  from vellum.workflows.utils.uuids import uuid4_from_hash
10
10
  from vellum_ee.workflows.display.types import WorkflowDisplayContext
11
+ from vellum_ee.workflows.display.utils.exceptions import UnsupportedSerializationException
11
12
  from vellum_ee.workflows.display.utils.expressions import get_child_descriptor
12
13
  from vellum_ee.workflows.display.utils.vellum import (
13
14
  ConstantValuePointer,
@@ -76,7 +77,12 @@ def create_node_input_value_pointer_rules(
76
77
  node_input_value_pointer_rules.extend(rhs_rules)
77
78
  else:
78
79
  # Non-CoalesceExpression case
79
- node_input_value_pointer_rules.append(create_node_input_value_pointer_rule(value, display_context))
80
+ try:
81
+ rule = create_node_input_value_pointer_rule(value, display_context)
82
+ except UnsupportedSerializationException:
83
+ return node_input_value_pointer_rules
84
+
85
+ node_input_value_pointer_rules.append(rule)
80
86
  else:
81
87
  pointer = create_pointer(value, pointer_type)
82
88
  node_input_value_pointer_rules.append(pointer)