vellum-ai 1.4.2__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. vellum/__init__.py +14 -0
  2. vellum/client/__init__.py +3 -0
  3. vellum/client/core/client_wrapper.py +2 -2
  4. vellum/client/reference.md +160 -0
  5. vellum/client/resources/__init__.py +2 -0
  6. vellum/client/resources/integrations/__init__.py +4 -0
  7. vellum/client/resources/integrations/client.py +260 -0
  8. vellum/client/resources/integrations/raw_client.py +267 -0
  9. vellum/client/types/__init__.py +12 -0
  10. vellum/client/types/components_schemas_composio_execute_tool_request.py +5 -0
  11. vellum/client/types/components_schemas_composio_execute_tool_response.py +5 -0
  12. vellum/client/types/components_schemas_composio_tool_definition.py +5 -0
  13. vellum/client/types/composio_execute_tool_request.py +24 -0
  14. vellum/client/types/composio_execute_tool_response.py +24 -0
  15. vellum/client/types/composio_tool_definition.py +26 -0
  16. vellum/client/types/vellum_error_code_enum.py +2 -0
  17. vellum/client/types/vellum_sdk_error.py +1 -0
  18. vellum/client/types/workflow_event_error.py +1 -0
  19. vellum/resources/integrations/__init__.py +3 -0
  20. vellum/resources/integrations/client.py +3 -0
  21. vellum/resources/integrations/raw_client.py +3 -0
  22. vellum/types/components_schemas_composio_execute_tool_request.py +3 -0
  23. vellum/types/components_schemas_composio_execute_tool_response.py +3 -0
  24. vellum/types/components_schemas_composio_tool_definition.py +3 -0
  25. vellum/types/composio_execute_tool_request.py +3 -0
  26. vellum/types/composio_execute_tool_response.py +3 -0
  27. vellum/types/composio_tool_definition.py +3 -0
  28. vellum/workflows/descriptors/utils.py +3 -0
  29. vellum/workflows/emitters/vellum_emitter.py +4 -1
  30. vellum/workflows/integrations/__init__.py +5 -0
  31. vellum/workflows/integrations/tests/__init__.py +0 -0
  32. vellum/workflows/integrations/tests/test_vellum_integration_service.py +225 -0
  33. vellum/workflows/integrations/vellum_integration_service.py +96 -0
  34. vellum/workflows/nodes/bases/base.py +24 -3
  35. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +5 -0
  36. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +2 -5
  37. vellum/workflows/nodes/displayable/tool_calling_node/utils.py +38 -4
  38. vellum/workflows/runner/runner.py +132 -110
  39. vellum/workflows/utils/functions.py +29 -18
  40. vellum/workflows/utils/tests/test_functions.py +40 -0
  41. vellum/workflows/workflows/base.py +23 -5
  42. vellum/workflows/workflows/tests/test_base_workflow.py +99 -0
  43. {vellum_ai-1.4.2.dist-info → vellum_ai-1.5.1.dist-info}/METADATA +1 -1
  44. {vellum_ai-1.4.2.dist-info → vellum_ai-1.5.1.dist-info}/RECORD +64 -41
  45. vellum_ai-1.5.1.dist-info/entry_points.txt +4 -0
  46. vellum_ee/assets/node-definitions.json +833 -0
  47. vellum_ee/scripts/generate_node_definitions.py +89 -0
  48. vellum_ee/workflows/display/nodes/base_node_display.py +6 -3
  49. vellum_ee/workflows/display/nodes/vellum/api_node.py +4 -7
  50. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +19 -5
  51. vellum_ee/workflows/display/nodes/vellum/retry_node.py +2 -3
  52. vellum_ee/workflows/display/nodes/vellum/search_node.py +3 -6
  53. vellum_ee/workflows/display/nodes/vellum/templating_node.py +1 -1
  54. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +2 -3
  55. vellum_ee/workflows/display/nodes/vellum/try_node.py +3 -4
  56. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +5 -11
  57. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_outputs_serialization.py +1 -1
  58. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_ports_serialization.py +1 -1
  59. vellum_ee/workflows/display/types.py +3 -3
  60. vellum_ee/workflows/display/utils/expressions.py +10 -3
  61. vellum_ee/workflows/display/utils/vellum.py +9 -2
  62. vellum_ee/workflows/display/workflows/base_workflow_display.py +2 -2
  63. vellum_ai-1.4.2.dist-info/entry_points.txt +0 -3
  64. {vellum_ai-1.4.2.dist-info → vellum_ai-1.5.1.dist-info}/LICENSE +0 -0
  65. {vellum_ai-1.4.2.dist-info → vellum_ai-1.5.1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,89 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ from typing import Any, Dict, List, Optional, Type
5
+
6
+ from vellum.workflows.nodes.bases.base import BaseNode
7
+ import vellum.workflows.nodes.displayable as displayable_module
8
+ from vellum.workflows.vellum_client import create_vellum_client
9
+ from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
10
+ from vellum_ee.workflows.display.types import WorkflowDisplayContext
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
def create_display_context_with_client() -> WorkflowDisplayContext:
    """Build a dry-run ``WorkflowDisplayContext`` backed by a newly created Vellum client.

    ``dry_run=True`` is passed so the context is configured for offline serialization
    of node definitions.
    """
    return WorkflowDisplayContext(client=create_vellum_client(), dry_run=True)
19
+
20
+
21
def get_all_displayable_node_classes() -> List[Type[BaseNode]]:
    """Collect every object named in ``vellum.workflows.nodes.displayable.__all__``.

    The exported names are expected to be node classes. They are returned in the same
    order as they appear in ``__all__``.
    """
    return [getattr(displayable_module, exported_name) for exported_name in displayable_module.__all__]
28
+
29
+
30
def clean_node_definition(definition: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``definition`` without the keys excluded from the node-definitions asset.

    The dropped keys are ``inputs``, ``data``, ``type``, ``adornments`` and ``should_file_merge``.
    The input dict is not modified, and the remaining entries keep their original order.
    """
    excluded_keys = {"inputs", "data", "type", "adornments", "should_file_merge"}
    return {key: value for key, value in definition.items() if key not in excluded_keys}
35
+
36
+
37
def _serialize_node_definition_or_raise(
    node_class: Type[BaseNode], display_context: WorkflowDisplayContext
) -> Dict[str, Any]:
    """Serialize and clean a single node definition, propagating any exception raised on the way."""
    display_class = get_node_display_class(node_class)
    display_instance = display_class()
    definition = display_instance.serialize(display_context)
    return clean_node_definition(definition)


def serialize_node_definition(
    node_class: Type[BaseNode], display_context: WorkflowDisplayContext
) -> Optional[Dict[str, Any]]:
    """Serialize a single node definition, returning None if it fails.

    Any exception raised while resolving the display class, serializing, or cleaning the
    definition is logged as a warning and swallowed.
    """
    try:
        return _serialize_node_definition_or_raise(node_class, display_context)
    except Exception as e:
        logger.warning("Failed to serialize %s: %s", node_class.__name__, e)
        return None


def main(output_path: Optional[str] = None) -> None:
    """Generate the node-definitions JSON file for every displayable node class.

    Each class from ``get_all_displayable_node_classes()`` is serialized once with a dry-run
    display context. The resulting file contains:

    * ``"nodes"``: the cleaned definitions that serialized successfully.
    * ``"errors"``: one ``{"node": <class name>, "error": <message>}`` entry per failure.

    Args:
        output_path: Destination of the JSON file. Defaults to ``assets/node-definitions.json``
            in the package directory that contains this ``scripts/`` folder. The default is
            therefore independent of the current working directory.
    """
    logger.info("Generating node definitions...")

    display_context = create_display_context_with_client()
    node_classes = get_all_displayable_node_classes()

    successful_nodes: List[Dict[str, Any]] = []
    errors: List[Dict[str, str]] = []

    for node_class in node_classes:
        logger.info("Serializing %s...", node_class.__name__)
        # Serialize exactly once and keep the exception. Re-running serialization to recover
        # the error doubles the work, and a node that fails only on the first attempt would
        # otherwise be dropped from both "nodes" and "errors".
        try:
            definition = _serialize_node_definition_or_raise(node_class, display_context)
        except Exception as e:
            logger.warning("Failed to serialize %s: %s", node_class.__name__, e)
            errors.append({"node": node_class.__name__, "error": str(e)})
        else:
            successful_nodes.append(definition)

    result = {"nodes": successful_nodes, "errors": errors}

    if output_path is None:
        # <package>/scripts/generate_node_definitions.py -> <package>/assets/node-definitions.json
        package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        output_path = os.path.join(package_dir, "assets", "node-definitions.json")

    output_dir = os.path.dirname(output_path)
    if output_dir:  # os.makedirs("") raises, e.g. for a bare filename
        os.makedirs(output_dir, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)

    logger.info(
        "Generated %d successful node definitions and %d errors in %s",
        len(successful_nodes),
        len(errors),
        output_path,
    )
86
+
87
+
88
if __name__ == "__main__":
    # With no handler configured, only WARNING and above reach the console. The INFO-level
    # progress and summary messages logged by main() would otherwise never be shown.
    logging.basicConfig(level=logging.INFO, format="%(message)s")
    main()
@@ -323,12 +323,12 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
323
323
  )
324
324
  return node_definition
325
325
 
326
- def get_node_output_display(self, output: OutputReference) -> Tuple[Type[BaseNode], NodeOutputDisplay]:
326
+ def get_node_output_display(self, output: OutputReference) -> NodeOutputDisplay:
327
327
  explicit_display = self.output_display.get(output)
328
328
  if explicit_display:
329
- return self._node, explicit_display
329
+ return explicit_display
330
330
 
331
- return (self._node, NodeOutputDisplay(id=uuid4_from_hash(f"{self.node_id}|{output.name}"), name=output.name))
331
+ return NodeOutputDisplay(id=uuid4_from_hash(f"{self.node_id}|{output.name}"), name=output.name)
332
332
 
333
333
  def get_node_port_display(self, port: Port) -> PortDisplay:
334
334
  overrides = self.port_displays.get(port)
@@ -347,6 +347,9 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
347
347
  if default_port in port_displays:
348
348
  return port_displays[default_port].id
349
349
 
350
+ if default_port:
351
+ return self.get_node_port_display(default_port).id
352
+
350
353
  first_port = next((port for port in unadorned_node.Ports), None)
351
354
  if not first_port:
352
355
  raise ValueError(f"Node {self._node.__name__} must have at least one port.")
@@ -1,8 +1,7 @@
1
1
  from uuid import UUID
2
- from typing import ClassVar, Dict, Generic, Optional, TypeVar, cast
2
+ from typing import ClassVar, Dict, Generic, Optional, TypeVar
3
3
 
4
4
  from vellum.workflows.nodes.displayable import APINode
5
- from vellum.workflows.references.output import OutputReference
6
5
  from vellum.workflows.types.core import JsonArray, JsonObject
7
6
  from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
8
7
  from vellum_ee.workflows.display.nodes.utils import raise_if_descriptor
@@ -178,11 +177,9 @@ class BaseAPINodeDisplay(BaseNodeDisplay[_APINodeType], Generic[_APINodeType]):
178
177
  ]
179
178
  inputs.extend(additional_header_inputs)
180
179
 
181
- _, text_output_display = display_context.global_node_output_displays[cast(OutputReference, node.Outputs.text)]
182
- _, json_output_display = display_context.global_node_output_displays[cast(OutputReference, node.Outputs.json)]
183
- _, status_code_output_display = display_context.global_node_output_displays[
184
- cast(OutputReference, node.Outputs.status_code)
185
- ]
180
+ text_output_display = self.get_node_output_display(node.Outputs.text)
181
+ json_output_display = self.get_node_output_display(node.Outputs.json)
182
+ status_code_output_display = self.get_node_output_display(node.Outputs.status_code)
186
183
 
187
184
  serialized_node: JsonObject = {
188
185
  "id": str(node_id),
@@ -7,6 +7,8 @@ from vellum.workflows.inputs.base import BaseInputs
7
7
  from vellum.workflows.nodes import InlineSubworkflowNode
8
8
  from vellum.workflows.nodes.displayable.bases.utils import primitive_to_vellum_value
9
9
  from vellum.workflows.types.core import JsonObject
10
+ from vellum.workflows.workflows.base import BaseWorkflow
11
+ from vellum_ee.workflows.display.exceptions import NodeValidationError
10
12
  from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
11
13
  from vellum_ee.workflows.display.nodes.utils import raise_if_descriptor
12
14
  from vellum_ee.workflows.display.nodes.vellum.utils import create_node_input
@@ -31,14 +33,25 @@ class BaseInlineSubworkflowNodeDisplay(
31
33
  node = self._node
32
34
  node_id = self.node_id
33
35
 
34
- node_inputs, workflow_inputs = self._generate_node_and_workflow_inputs(node_id, node, display_context)
36
+ subworkflow_class = raise_if_descriptor(node.subworkflow)
37
+ if subworkflow_class is None:
38
+ display_context.add_error(
39
+ NodeValidationError(
40
+ "InlineSubworkflowNode requires a subworkflow to be defined",
41
+ node_class_name=node.__class__.__name__,
42
+ )
43
+ )
44
+ subworkflow_class = BaseWorkflow
35
45
 
46
+ node_inputs, workflow_inputs = self._generate_node_and_workflow_inputs(
47
+ node_id, node, display_context, subworkflow_class
48
+ )
36
49
  subworkflow_display = get_workflow_display(
37
50
  base_display_class=display_context.workflow_display_class,
38
- workflow_class=raise_if_descriptor(node.subworkflow),
51
+ workflow_class=subworkflow_class,
39
52
  parent_display_context=display_context,
40
53
  )
41
- workflow_outputs = self._generate_workflow_outputs(node, subworkflow_display.display_context)
54
+ workflow_outputs = self._generate_workflow_outputs(node, subworkflow_display.display_context, subworkflow_class)
42
55
  serialized_subworkflow = subworkflow_display.serialize()
43
56
 
44
57
  return {
@@ -63,8 +76,8 @@ class BaseInlineSubworkflowNodeDisplay(
63
76
  node_id: UUID,
64
77
  node: Type[InlineSubworkflowNode],
65
78
  display_context: WorkflowDisplayContext,
79
+ subworkflow: Type[BaseWorkflow],
66
80
  ) -> Tuple[List[NodeInput], List[VellumVariable]]:
67
- subworkflow = raise_if_descriptor(node.subworkflow)
68
81
  subworkflow_inputs_class = subworkflow.get_inputs_class()
69
82
  subworkflow_inputs = raise_if_descriptor(node.subworkflow_inputs)
70
83
 
@@ -115,9 +128,10 @@ class BaseInlineSubworkflowNodeDisplay(
115
128
  self,
116
129
  node: Type[InlineSubworkflowNode],
117
130
  display_context: WorkflowDisplayContext,
131
+ subworkflow: Type[BaseWorkflow],
118
132
  ) -> List[VellumVariable]:
119
133
  workflow_outputs: List[VellumVariable] = []
120
- for output_descriptor in raise_if_descriptor(node.subworkflow).Outputs: # type: ignore[union-attr]
134
+ for output_descriptor in subworkflow.Outputs: # type: ignore[union-attr]
121
135
  workflow_output_display = display_context.workflow_output_displays[output_descriptor]
122
136
  output_type = infer_vellum_variable_type(output_descriptor)
123
137
  workflow_outputs.append(
@@ -1,7 +1,6 @@
1
1
  import inspect
2
- from typing import Any, Generic, Tuple, Type, TypeVar
2
+ from typing import Any, Generic, TypeVar
3
3
 
4
- from vellum.workflows.nodes.bases.base import BaseNode
5
4
  from vellum.workflows.nodes.core.retry_node.node import RetryNode
6
5
  from vellum.workflows.nodes.utils import ADORNMENT_MODULE_NAME
7
6
  from vellum.workflows.references.output import OutputReference
@@ -66,7 +65,7 @@ class BaseRetryNodeDisplay(BaseAdornmentNodeDisplay[_RetryNodeType], Generic[_Re
66
65
 
67
66
  return serialized_node
68
67
 
69
- def get_node_output_display(self, output: OutputReference) -> Tuple[Type[BaseNode], NodeOutputDisplay]:
68
+ def get_node_output_display(self, output: OutputReference) -> NodeOutputDisplay:
70
69
  inner_node = self._node.__wrapped_node__
71
70
  if not inner_node:
72
71
  return super().get_node_output_display(output)
@@ -1,6 +1,6 @@
1
1
  from dataclasses import dataclass
2
2
  from uuid import UUID
3
- from typing import Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, cast
3
+ from typing import Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union
4
4
 
5
5
  from vellum import (
6
6
  MetadataFilterConfigRequest,
@@ -9,7 +9,6 @@ from vellum import (
9
9
  )
10
10
  from vellum.workflows.nodes.displayable.bases.types import MetadataLogicalCondition, MetadataLogicalConditionGroup
11
11
  from vellum.workflows.nodes.displayable.search_node import SearchNode
12
- from vellum.workflows.references import OutputReference
13
12
  from vellum.workflows.types.core import JsonArray, JsonObject
14
13
  from vellum.workflows.utils.uuids import uuid4_from_hash
15
14
  from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
@@ -49,10 +48,8 @@ class BaseSearchNodeDisplay(BaseNodeDisplay[_SearchNodeType], Generic[_SearchNod
49
48
  node_id = self.node_id
50
49
  node_inputs = self._generate_search_node_inputs(node_id, node, display_context)
51
50
 
52
- _, results_output_display = display_context.global_node_output_displays[
53
- cast(OutputReference, node.Outputs.results)
54
- ]
55
- _, text_output_display = display_context.global_node_output_displays[cast(OutputReference, node.Outputs.text)]
51
+ results_output_display = display_context.global_node_output_displays[node.Outputs.results]
52
+ text_output_display = display_context.global_node_output_displays[node.Outputs.text]
56
53
 
57
54
  return {
58
55
  "id": str(node_id),
@@ -50,7 +50,7 @@ class BaseTemplatingNodeDisplay(BaseNodeDisplay[_TemplatingNodeType], Generic[_T
50
50
  # Misc type ignore is due to `node.Outputs` being generic
51
51
  # https://app.shortcut.com/vellum/story/4784
52
52
  output_descriptor = node.Outputs.result # type: ignore [misc]
53
- _, output_display = display_context.global_node_output_displays[output_descriptor]
53
+ output_display = display_context.global_node_output_displays[output_descriptor]
54
54
  inferred_output_type = primitive_type_to_vellum_variable_type(output_descriptor)
55
55
 
56
56
  return {
@@ -121,9 +121,8 @@ def test_create_node_input_value_pointer_rules(
121
121
  ),
122
122
  },
123
123
  global_node_output_displays={
124
- MyNodeA.Outputs.output: (
125
- MyNodeA,
126
- NodeOutputDisplay(id=UUID("4b16a629-11a1-4b3f-a965-a57b872d13b8"), name="output"),
124
+ MyNodeA.Outputs.output: NodeOutputDisplay(
125
+ id=UUID("4b16a629-11a1-4b3f-a965-a57b872d13b8"), name="output"
127
126
  ),
128
127
  },
129
128
  global_node_displays={
@@ -1,8 +1,7 @@
1
1
  import inspect
2
2
  from uuid import UUID
3
- from typing import Any, ClassVar, Generic, Optional, Tuple, Type, TypeVar
3
+ from typing import Any, ClassVar, Generic, Optional, TypeVar
4
4
 
5
- from vellum.workflows.nodes.bases.base import BaseNode
6
5
  from vellum.workflows.nodes.core.try_node.node import TryNode
7
6
  from vellum.workflows.nodes.utils import ADORNMENT_MODULE_NAME
8
7
  from vellum.workflows.references.output import OutputReference
@@ -76,7 +75,7 @@ class BaseTryNodeDisplay(BaseAdornmentNodeDisplay[_TryNodeType], Generic[_TryNod
76
75
 
77
76
  return serialized_node
78
77
 
79
- def get_node_output_display(self, output: OutputReference) -> Tuple[Type[BaseNode], NodeOutputDisplay]:
78
+ def get_node_output_display(self, output: OutputReference) -> NodeOutputDisplay:
80
79
  inner_node = self._node.__wrapped_node__
81
80
  if not inner_node:
82
81
  return super().get_node_output_display(output)
@@ -84,7 +83,7 @@ class BaseTryNodeDisplay(BaseAdornmentNodeDisplay[_TryNodeType], Generic[_TryNod
84
83
  node_display_class = get_node_display_class(inner_node)
85
84
  node_display = node_display_class()
86
85
  if output.name == "error":
87
- return inner_node, NodeOutputDisplay(
86
+ return NodeOutputDisplay(
88
87
  id=self.error_output_id or uuid4_from_hash(f"{node_display.node_id}|error_output_id"),
89
88
  name="error",
90
89
  )
@@ -329,7 +329,7 @@ def test_serialize_node__node_output(serialize_node):
329
329
  global_workflow_input_displays={Inputs.input: WorkflowInputsDisplay(id=workflow_input_id)},
330
330
  global_node_displays={NodeWithOutput: NodeWithOutputDisplay()},
331
331
  global_node_output_displays={
332
- NodeWithOutput.Outputs.output: (NodeWithOutput, NodeOutputDisplay(id=node_output_id, name="output"))
332
+ NodeWithOutput.Outputs.output: NodeOutputDisplay(id=node_output_id, name="output")
333
333
  },
334
334
  )
335
335
 
@@ -521,14 +521,8 @@ def test_serialize_node__coalesce(serialize_node):
521
521
  CoalesceNodeFinal: CoalesceNodeFinalDisplay(),
522
522
  },
523
523
  global_node_output_displays={
524
- CoalesceNodeA.Outputs.output: (
525
- CoalesceNodeA,
526
- NodeOutputDisplay(id=coalesce_node_a_output_id, name="output"),
527
- ),
528
- CoalesceNodeB.Outputs.output: (
529
- CoalesceNodeB,
530
- NodeOutputDisplay(id=coalesce_node_b_output_id, name="output"),
531
- ),
524
+ CoalesceNodeA.Outputs.output: NodeOutputDisplay(id=coalesce_node_a_output_id, name="output"),
525
+ CoalesceNodeB.Outputs.output: NodeOutputDisplay(id=coalesce_node_b_output_id, name="output"),
532
526
  },
533
527
  )
534
528
 
@@ -602,7 +596,7 @@ def test_serialize_node__dataclass_with_node_output_reference(serialize_node):
602
596
  node_class=GenericNodeWithDataclass,
603
597
  global_node_displays={NodeWithOutput: NodeWithOutputDisplay()},
604
598
  global_node_output_displays={
605
- NodeWithOutput.Outputs.result: (NodeWithOutput, NodeOutputDisplay(id=node_output_id, name="result"))
599
+ NodeWithOutput.Outputs.result: NodeOutputDisplay(id=node_output_id, name="result")
606
600
  },
607
601
  )
608
602
 
@@ -634,7 +628,7 @@ def test_serialize_node__pydantic_with_node_output_reference(serialize_node):
634
628
  node_class=GenericNodeWithPydantic,
635
629
  global_node_displays={NodeWithOutput: NodeWithOutputDisplay()},
636
630
  global_node_output_displays={
637
- NodeWithOutput.Outputs.result: (NodeWithOutput, NodeOutputDisplay(id=node_output_id, name="result"))
631
+ NodeWithOutput.Outputs.result: NodeOutputDisplay(id=node_output_id, name="result")
638
632
  },
639
633
  )
640
634
 
@@ -127,7 +127,7 @@ def test_serialize_node__node_output_reference(serialize_node):
127
127
  global_workflow_input_displays={Inputs.input: WorkflowInputsDisplay(id=workflow_input_id)},
128
128
  global_node_displays={NodeWithOutput: NodeWithOutputDisplay()},
129
129
  global_node_output_displays={
130
- NodeWithOutput.Outputs.output: (NodeWithOutput, NodeOutputDisplay(id=node_output_id, name="output"))
130
+ NodeWithOutput.Outputs.output: NodeOutputDisplay(id=node_output_id, name="output")
131
131
  },
132
132
  )
133
133
 
@@ -292,7 +292,7 @@ def test_serialize_node__node_output_reference(serialize_node):
292
292
  global_workflow_input_displays={Inputs.input: WorkflowInputsDisplay(id=workflow_input_id)},
293
293
  global_node_displays={NodeWithOutput: NodeWithOutputDisplay()},
294
294
  global_node_output_displays={
295
- NodeWithOutput.Outputs.output: (NodeWithOutput, NodeOutputDisplay(id=node_output_id, name="output"))
295
+ NodeWithOutput.Outputs.output: NodeOutputDisplay(id=node_output_id, name="output")
296
296
  },
297
297
  )
298
298
 
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
28
28
  WorkflowInputsDisplays = Dict[WorkflowInputReference, WorkflowInputsDisplay]
29
29
  StateValueDisplays = Dict[StateValueReference, StateValueDisplay]
30
30
  NodeDisplays = Dict[Type[BaseNode], BaseNodeDisplay]
31
- NodeOutputDisplays = Dict[OutputReference, Tuple[Type[BaseNode], NodeOutputDisplay]]
31
+ NodeOutputDisplays = Dict[OutputReference, NodeOutputDisplay]
32
32
  EntrypointDisplays = Dict[Type[BaseNode], EntrypointDisplay]
33
33
  WorkflowOutputDisplays = Dict[BaseDescriptor, WorkflowOutputDisplay]
34
34
  EdgeDisplays = Dict[Tuple[Port, Type[BaseNode]], EdgeDisplay]
@@ -51,12 +51,12 @@ class WorkflowDisplayContext:
51
51
  workflow_output_displays: WorkflowOutputDisplays = field(default_factory=dict)
52
52
  edge_displays: EdgeDisplays = field(default_factory=dict)
53
53
  port_displays: PortDisplays = field(default_factory=dict)
54
+ dry_run: bool = False
54
55
  _errors: List[Exception] = field(default_factory=list)
55
56
  _invalid_nodes: List[Type[BaseNode]] = field(default_factory=list)
56
- _dry_run: bool = False
57
57
 
58
58
  def add_error(self, error: Exception, node: Optional[Type[BaseNode]] = None) -> None:
59
- if self._dry_run:
59
+ if self.dry_run:
60
60
  self._errors.append(error)
61
61
  return
62
62
 
@@ -41,7 +41,9 @@ from vellum.workflows.expressions.not_between import NotBetweenExpression
41
41
  from vellum.workflows.expressions.not_in import NotInExpression
42
42
  from vellum.workflows.expressions.or_ import OrExpression
43
43
  from vellum.workflows.expressions.parse_json import ParseJsonExpression
44
+ from vellum.workflows.nodes.bases.base import BaseNode
44
45
  from vellum.workflows.nodes.displayable.bases.utils import primitive_to_vellum_value
46
+ from vellum.workflows.nodes.utils import get_unadorned_node
45
47
  from vellum.workflows.references.constant import ConstantValueReference
46
48
  from vellum.workflows.references.environment_variable import EnvironmentVariableReference
47
49
  from vellum.workflows.references.execution_count import ExecutionCountReference
@@ -296,12 +298,17 @@ def serialize_value(executable_id: UUID, display_context: "WorkflowDisplayContex
296
298
  }
297
299
 
298
300
  if isinstance(value, OutputReference):
299
- upstream_node, output_display = display_context.global_node_output_displays[value]
300
- upstream_node_display = display_context.global_node_displays[upstream_node]
301
+ output_display = display_context.global_node_output_displays[value]
302
+
303
+ upstream_node_class = value.outputs_class.__parent_class__
304
+ if not issubclass(upstream_node_class, BaseNode):
305
+ raise ValueError(f"Output references must be to a node, not {upstream_node_class}")
306
+ unadorned_upstream_node_class = get_unadorned_node(upstream_node_class)
307
+ upstream_node = display_context.global_node_displays[unadorned_upstream_node_class]
301
308
 
302
309
  return {
303
310
  "type": "NODE_OUTPUT",
304
- "node_id": str(upstream_node_display.node_id),
311
+ "node_id": str(upstream_node.node_id),
305
312
  "node_output_id": str(output_display.id),
306
313
  }
307
314
 
@@ -8,6 +8,7 @@ from vellum.client.types.vellum_variable_type import VellumVariableType
8
8
  from vellum.workflows.descriptors.base import BaseDescriptor
9
9
  from vellum.workflows.nodes.bases.base import BaseNode
10
10
  from vellum.workflows.nodes.displayable.bases.utils import primitive_to_vellum_value
11
+ from vellum.workflows.nodes.utils import get_unadorned_node
11
12
  from vellum.workflows.references import OutputReference, WorkflowInputReference
12
13
  from vellum.workflows.references.execution_count import ExecutionCountReference
13
14
  from vellum.workflows.references.lazy import LazyReference
@@ -108,8 +109,14 @@ def create_node_input_value_pointer_rule(
108
109
 
109
110
  raise ValueError(f"Reference to outputs '{value.outputs_class.__qualname__}' is invalid.")
110
111
 
111
- upstream_node, output_display = display_context.global_node_output_displays[value]
112
- upstream_node_display = display_context.global_node_displays[upstream_node]
112
+ output_display = display_context.global_node_output_displays[value]
113
+
114
+ upstream_node_class = value.outputs_class.__parent_class__
115
+ if not issubclass(upstream_node_class, BaseNode):
116
+ raise ValueError(f"Output references must be to a node, not {upstream_node_class}")
117
+ unadorned_upstream_node_class = get_unadorned_node(upstream_node_class)
118
+ upstream_node_display = display_context.global_node_displays[unadorned_upstream_node_class]
119
+
113
120
  return NodeOutputPointer(
114
121
  data=NodeOutputData(node_id=str(upstream_node_display.node_id), output_id=str(output_display.id)),
115
122
  )
@@ -479,7 +479,7 @@ class BaseWorkflowDisplay(Generic[WorkflowType]):
479
479
  self,
480
480
  node: Type[BaseNode],
481
481
  node_display: BaseNodeDisplay,
482
- node_output_displays: Dict[OutputReference, Tuple[Type[BaseNode], NodeOutputDisplay]],
482
+ node_output_displays: Dict[OutputReference, NodeOutputDisplay],
483
483
  ):
484
484
  """This method recursively adds nodes wrapped in decorators to the node_output_displays dictionary."""
485
485
 
@@ -624,7 +624,7 @@ class BaseWorkflowDisplay(Generic[WorkflowType]):
624
624
  edge_displays=edge_displays,
625
625
  port_displays=port_displays,
626
626
  workflow_display_class=self.__class__,
627
- _dry_run=self._dry_run,
627
+ dry_run=self._dry_run,
628
628
  )
629
629
 
630
630
  def _generate_workflow_meta_display(self) -> WorkflowMetaDisplay:
@@ -1,3 +0,0 @@
1
- [console_scripts]
2
- vellum=vellum_cli:main
3
-