vellum-ai 1.7.7__py3-none-any.whl → 1.7.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53):
  1. vellum/client/core/client_wrapper.py +2 -2
  2. vellum/client/reference.md +16 -0
  3. vellum/client/resources/ad_hoc/raw_client.py +2 -2
  4. vellum/client/resources/integration_providers/client.py +20 -0
  5. vellum/client/resources/integration_providers/raw_client.py +20 -0
  6. vellum/client/types/integration_name.py +1 -0
  7. vellum/client/types/workflow_execution_fulfilled_body.py +1 -0
  8. vellum/workflows/errors/types.py +3 -0
  9. vellum/workflows/graph/graph.py +4 -1
  10. vellum/workflows/nodes/core/map_node/node.py +10 -0
  11. vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +49 -1
  12. vellum/workflows/nodes/tests/test_utils.py +7 -1
  13. vellum/workflows/nodes/utils.py +1 -1
  14. vellum/workflows/references/__init__.py +2 -0
  15. vellum/workflows/references/trigger.py +83 -0
  16. vellum/workflows/runner/runner.py +17 -5
  17. vellum/workflows/state/base.py +49 -1
  18. vellum/workflows/triggers/__init__.py +2 -1
  19. vellum/workflows/triggers/base.py +140 -3
  20. vellum/workflows/triggers/integration.py +31 -26
  21. vellum/workflows/triggers/slack.py +101 -0
  22. vellum/workflows/triggers/tests/test_integration.py +55 -31
  23. vellum/workflows/triggers/tests/test_slack.py +180 -0
  24. vellum/workflows/utils/functions.py +1 -1
  25. vellum/workflows/workflows/base.py +2 -2
  26. {vellum_ai-1.7.7.dist-info → vellum_ai-1.7.9.dist-info}/METADATA +1 -1
  27. {vellum_ai-1.7.7.dist-info → vellum_ai-1.7.9.dist-info}/RECORD +53 -49
  28. vellum_ee/assets/node-definitions.json +91 -65
  29. vellum_ee/workflows/display/base.py +3 -0
  30. vellum_ee/workflows/display/nodes/base_node_display.py +46 -26
  31. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +1 -1
  32. vellum_ee/workflows/display/nodes/vellum/conditional_node.py +1 -1
  33. vellum_ee/workflows/display/nodes/vellum/error_node.py +1 -1
  34. vellum_ee/workflows/display/nodes/vellum/final_output_node.py +1 -1
  35. vellum_ee/workflows/display/nodes/vellum/guardrail_node.py +1 -1
  36. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +1 -1
  37. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +1 -1
  38. vellum_ee/workflows/display/nodes/vellum/map_node.py +1 -1
  39. vellum_ee/workflows/display/nodes/vellum/merge_node.py +1 -1
  40. vellum_ee/workflows/display/nodes/vellum/note_node.py +1 -1
  41. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +1 -1
  42. vellum_ee/workflows/display/nodes/vellum/search_node.py +1 -1
  43. vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +1 -1
  44. vellum_ee/workflows/display/nodes/vellum/templating_node.py +1 -1
  45. vellum_ee/workflows/display/nodes/vellum/tests/test_api_node.py +34 -0
  46. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +8 -0
  47. vellum_ee/workflows/display/tests/workflow_serialization/test_manual_trigger_serialization.py +33 -70
  48. vellum_ee/workflows/display/tests/workflow_serialization/test_slack_trigger_serialization.py +167 -0
  49. vellum_ee/workflows/display/utils/expressions.py +12 -0
  50. vellum_ee/workflows/display/workflows/base_workflow_display.py +23 -7
  51. {vellum_ai-1.7.7.dist-info → vellum_ai-1.7.9.dist-info}/LICENSE +0 -0
  52. {vellum_ai-1.7.7.dist-info → vellum_ai-1.7.9.dist-info}/WHEEL +0 -0
  53. {vellum_ai-1.7.7.dist-info → vellum_ai-1.7.9.dist-info}/entry_points.txt +0 -0
@@ -8,6 +8,7 @@ from typing import (
8
8
  Dict,
9
9
  ForwardRef,
10
10
  Generic,
11
+ List,
11
12
  Optional,
12
13
  Set,
13
14
  Tuple,
@@ -183,32 +184,13 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
183
184
  existing_adornments = adornments if adornments is not None else []
184
185
  return display_class().serialize(display_context, adornments=existing_adornments + [adornment])
185
186
 
186
- outputs: JsonArray = []
187
- for output in node.Outputs:
188
- type = primitive_type_to_vellum_variable_type(output)
189
- value = (
190
- serialize_value(node_id, display_context, output.instance)
191
- if output.instance is not None and output.instance is not undefined
192
- else None
193
- )
194
-
195
- outputs.append(
196
- {
197
- "id": str(uuid4_from_hash(f"{node_id}|{output.name}")),
198
- "name": output.name,
199
- "type": type,
200
- "value": value,
201
- }
202
- )
203
-
204
187
  return {
205
188
  "id": str(node_id),
206
189
  "label": self.label,
207
190
  "type": "GENERIC",
208
- **self.serialize_generic_fields(display_context),
209
191
  "adornments": adornments,
210
192
  "attributes": attributes,
211
- "outputs": outputs,
193
+ **self.serialize_generic_fields(display_context),
212
194
  }
213
195
 
214
196
  def serialize_ports(self, display_context: "WorkflowDisplayContext") -> JsonArray:
@@ -275,14 +257,49 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
275
257
 
276
258
  return attributes
277
259
 
278
- def serialize_generic_fields(self, display_context: "WorkflowDisplayContext") -> JsonObject:
260
+ def _serialize_outputs(self, display_context: "WorkflowDisplayContext") -> JsonArray:
261
+ """Generate outputs array from node output displays or node.Outputs."""
262
+ outputs: JsonArray = []
263
+ node = self._node
264
+
265
+ for output in node.Outputs:
266
+ output_type = primitive_type_to_vellum_variable_type(output)
267
+ value = (
268
+ serialize_value(self.node_id, display_context, output.instance)
269
+ if output.instance is not None and output.instance != undefined
270
+ else None
271
+ )
272
+
273
+ output_id = (
274
+ str(self.output_display[output].id)
275
+ if output in self.output_display
276
+ else str(uuid4_from_hash(f"{self.node_id}|{output.name}"))
277
+ )
278
+
279
+ outputs.append(
280
+ {
281
+ "id": output_id,
282
+ "name": output.name,
283
+ "type": output_type,
284
+ "value": value,
285
+ }
286
+ )
287
+
288
+ return outputs
289
+
290
+ def serialize_generic_fields(
291
+ self, display_context: "WorkflowDisplayContext", exclude: Optional[List[str]] = None
292
+ ) -> JsonObject:
279
293
  """Serialize generic fields that are common to all nodes."""
294
+ exclude = exclude or []
295
+
280
296
  result: JsonObject = {
281
- "display_data": self.get_display_data().dict(),
282
- "base": self.get_base().dict(),
283
- "definition": self.get_definition().dict(),
284
- "trigger": self.serialize_trigger(),
285
- "ports": self.serialize_ports(display_context),
297
+ "display_data": self.get_display_data().dict() if "display_data" not in exclude else None,
298
+ "base": self.get_base().dict() if "base" not in exclude else None,
299
+ "definition": self.get_definition().dict() if "definition" not in exclude else None,
300
+ "trigger": self.serialize_trigger() if "trigger" not in exclude else None,
301
+ "ports": self.serialize_ports(display_context) if "ports" not in exclude else None,
302
+ "outputs": self._serialize_outputs(display_context) if "outputs" not in exclude else None,
286
303
  }
287
304
 
288
305
  # Only include should_file_merge if there are custom methods defined
@@ -299,6 +316,9 @@ class BaseNodeDisplay(Generic[NodeType], metaclass=BaseNodeDisplayMeta):
299
316
  except Exception:
300
317
  pass
301
318
 
319
+ for key in exclude:
320
+ result.pop(key, None)
321
+
302
322
  return result
303
323
 
304
324
  def get_base(self) -> CodeResourceDefinition:
@@ -119,5 +119,5 @@ class BaseCodeExecutionNodeDisplay(BaseNodeDisplay[_CodeExecutionNodeType], Gene
119
119
  "output_id": str(self.output_id) if self.output_id else str(output_display.id),
120
120
  "log_output_id": str(self.log_output_id) if self.log_output_id else str(log_output_display.id),
121
121
  },
122
- **self.serialize_generic_fields(display_context),
122
+ **self.serialize_generic_fields(display_context, exclude=["outputs"]),
123
123
  }
@@ -217,7 +217,7 @@ but the defined conditions have length {len(condition_ids)}"""
217
217
  "conditions": conditions, # type: ignore
218
218
  "version": "2",
219
219
  },
220
- **self.serialize_generic_fields(display_context),
220
+ **self.serialize_generic_fields(display_context, exclude=["outputs"]),
221
221
  }
222
222
 
223
223
  def get_nested_rule_details_by_path(
@@ -47,7 +47,7 @@ class BaseErrorNodeDisplay(BaseNodeDisplay[_ErrorNodeType], Generic[_ErrorNodeTy
47
47
  "target_handle_id": str(self.get_target_handle_id()),
48
48
  "error_source_input_id": str(error_source_input_id),
49
49
  },
50
- **self.serialize_generic_fields(display_context),
50
+ **self.serialize_generic_fields(display_context, exclude=["outputs"]),
51
51
  }
52
52
 
53
53
  if self.name:
@@ -45,7 +45,7 @@ class BaseFinalOutputNodeDisplay(BaseNodeDisplay[_FinalOutputNodeType], Generic[
45
45
  "node_input_id": str(node_input.id),
46
46
  },
47
47
  "inputs": [node_input.dict()],
48
- **self.serialize_generic_fields(display_context),
48
+ **self.serialize_generic_fields(display_context, exclude=["outputs"]),
49
49
  "outputs": [
50
50
  {
51
51
  "id": str(self._get_output_id()),
@@ -45,5 +45,5 @@ class BaseGuardrailNodeDisplay(BaseNodeDisplay[_GuardrailNodeType], Generic[_Gua
45
45
  "metric_definition_id": str(raise_if_descriptor(node.metric_definition)),
46
46
  "release_tag": raise_if_descriptor(node.release_tag),
47
47
  },
48
- **self.serialize_generic_fields(display_context),
48
+ **self.serialize_generic_fields(display_context, exclude=["outputs"]),
49
49
  }
@@ -110,7 +110,7 @@ class BaseInlinePromptNodeDisplay(BaseNodeDisplay[_InlinePromptNodeType], Generi
110
110
  },
111
111
  "ml_model_name": ml_model,
112
112
  },
113
- **self.serialize_generic_fields(display_context),
113
+ **self.serialize_generic_fields(display_context, exclude=["outputs"]),
114
114
  "outputs": [
115
115
  {"id": str(json_display.id), "name": "json", "type": "JSON", "value": None},
116
116
  {"id": str(output_display.id), "name": "text", "type": "STRING", "value": None},
@@ -68,7 +68,7 @@ class BaseInlineSubworkflowNodeDisplay(
68
68
  "input_variables": [workflow_input.dict() for workflow_input in workflow_inputs],
69
69
  "output_variables": [workflow_output.dict() for workflow_output in workflow_outputs],
70
70
  },
71
- **self.serialize_generic_fields(display_context),
71
+ **self.serialize_generic_fields(display_context, exclude=["outputs"]),
72
72
  }
73
73
 
74
74
  def _generate_node_and_workflow_inputs(
@@ -88,7 +88,7 @@ class BaseMapNodeDisplay(BaseAdornmentNodeDisplay[_MapNodeType], Generic[_MapNod
88
88
  "item_input_id": item_workflow_input_id,
89
89
  "index_input_id": index_workflow_input_id,
90
90
  },
91
- **self.serialize_generic_fields(display_context),
91
+ **self.serialize_generic_fields(display_context, exclude=["outputs"]),
92
92
  }
93
93
 
94
94
  def _default_workflow_class(self) -> Type[BaseWorkflow]:
@@ -47,7 +47,7 @@ class BaseMergeNodeDisplay(BaseNodeDisplay[_MergeNodeType], Generic[_MergeNodeTy
47
47
  "target_handles": [{"id": str(target_handle_id)} for target_handle_id in target_handle_ids],
48
48
  "source_handle_id": str(self.get_source_handle_id(display_context.port_displays)),
49
49
  },
50
- **self.serialize_generic_fields(display_context),
50
+ **self.serialize_generic_fields(display_context, exclude=["outputs"]),
51
51
  }
52
52
 
53
53
  def get_target_handle_ids(self) -> Optional[List[UUID]]:
@@ -25,5 +25,5 @@ class BaseNoteNodeDisplay(BaseNodeDisplay[_NoteNodeType], Generic[_NoteNodeType]
25
25
  "text": self.text,
26
26
  "style": self.style,
27
27
  },
28
- **self.serialize_generic_fields(display_context),
28
+ **self.serialize_generic_fields(display_context, exclude=["outputs"]),
29
29
  }
@@ -71,7 +71,7 @@ class BasePromptDeploymentNodeDisplay(BaseNodeDisplay[_PromptDeploymentNodeType]
71
71
  "release_tag": raise_if_descriptor(node.release_tag),
72
72
  "ml_model_fallbacks": list(ml_model_fallbacks) if ml_model_fallbacks else None,
73
73
  },
74
- **self.serialize_generic_fields(display_context),
74
+ **self.serialize_generic_fields(display_context, exclude=["outputs"]),
75
75
  "outputs": [
76
76
  {"id": str(json_display.id), "name": "json", "type": "JSON", "value": None},
77
77
  {"id": str(output_display.id), "name": "text", "type": "STRING", "value": None},
@@ -71,7 +71,7 @@ class BaseSearchNodeDisplay(BaseNodeDisplay[_SearchNodeType], Generic[_SearchNod
71
71
  "external_id_filters_node_input_id": str(node_inputs["external_id_filters"].id),
72
72
  "metadata_filters_node_input_id": str(node_inputs["metadata_filters"].id),
73
73
  },
74
- **self.serialize_generic_fields(display_context),
74
+ **self.serialize_generic_fields(display_context, exclude=["outputs"]),
75
75
  }
76
76
 
77
77
  def _generate_search_node_inputs(
@@ -68,5 +68,5 @@ class BaseSubworkflowDeploymentNodeDisplay(
68
68
  "workflow_deployment_id": deployment_id,
69
69
  "release_tag": raise_if_descriptor(node.release_tag),
70
70
  },
71
- **self.serialize_generic_fields(display_context),
71
+ **self.serialize_generic_fields(display_context, exclude=["outputs"]),
72
72
  }
@@ -66,5 +66,5 @@ class BaseTemplatingNodeDisplay(BaseNodeDisplay[_TemplatingNodeType], Generic[_T
66
66
  "template_node_input_id": str(template_node_input.id),
67
67
  "output_type": inferred_output_type,
68
68
  },
69
- **self.serialize_generic_fields(display_context),
69
+ **self.serialize_generic_fields(display_context, exclude=["outputs"]),
70
70
  }
@@ -29,3 +29,37 @@ def test_serialize_node__api_node_with_timeout():
29
29
  assert timeout_attribute["value"]["type"] == "CONSTANT_VALUE"
30
30
  assert timeout_attribute["value"]["value"]["type"] == "NUMBER"
31
31
  assert timeout_attribute["value"]["value"]["value"] == 30.0
32
+
33
+
34
+ def test_serialize_node__api_node_outputs():
35
+ """
36
+ Tests that API node serialization includes the outputs array.
37
+ """
38
+
39
+ class MyAPINode(APINode):
40
+ url = "https://api.example.com"
41
+ method = APIRequestMethod.GET
42
+
43
+ class Workflow(BaseWorkflow):
44
+ graph = MyAPINode
45
+
46
+ workflow_display = get_workflow_display(workflow_class=Workflow)
47
+ serialized_workflow: dict = workflow_display.serialize()
48
+
49
+ my_api_node = next(node for node in serialized_workflow["workflow_raw_data"]["nodes"] if node["type"] == "API")
50
+
51
+ assert "outputs" in my_api_node
52
+ outputs = my_api_node["outputs"]
53
+ assert len(outputs) == 4
54
+
55
+ text_output = next(output for output in outputs if output["name"] == "text")
56
+ assert text_output["type"] == "STRING"
57
+ assert text_output["value"] is None
58
+
59
+ json_output = next(output for output in outputs if output["name"] == "json")
60
+ assert json_output["type"] == "JSON"
61
+ assert json_output["value"] is None
62
+
63
+ status_code_output = next(output for output in outputs if output["name"] == "status_code")
64
+ assert status_code_output["type"] == "NUMBER"
65
+ assert status_code_output["value"] is None
@@ -8,6 +8,8 @@ from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class imp
8
8
 
9
9
  from tests.workflows.basic_api_node.workflow import SimpleAPIWorkflow
10
10
 
11
+ # 0d76e1e1-3a4b-4eb4-a606-f73d62c -> 12e4a99d-883d-4da5-aa51-35817d94013e
12
+
11
13
 
12
14
  def test_serialize_workflow(vellum_client):
13
15
  # GIVEN a Workflow that uses a vellum API node
@@ -208,6 +210,12 @@ def test_serialize_workflow(vellum_client):
208
210
  "merge_behavior": "AWAIT_ANY",
209
211
  },
210
212
  "ports": [{"id": "7c33b4d3-9204-4bd5-9371-80ee34f83073", "name": "default", "type": "DEFAULT"}],
213
+ "outputs": [
214
+ {"id": "12e4a99d-883d-4da5-aa51-35817d94013e", "name": "json", "type": "JSON", "value": None},
215
+ {"id": "0d76e1e1-3a4b-4eb4-a606-f73d62cf1a7e", "name": "headers", "type": "JSON", "value": None},
216
+ {"id": "fecc16c3-400e-4fd3-8223-08366070e3b1", "name": "status_code", "type": "NUMBER", "value": None},
217
+ {"id": "17342c21-12bb-49ab-88ce-f144e0376b32", "name": "text", "type": "STRING", "value": None},
218
+ ],
211
219
  },
212
220
  api_node,
213
221
  )
@@ -1,7 +1,4 @@
1
- """Tests for serialization of workflows with ManualTrigger."""
2
-
3
1
  import pytest
4
- from typing import cast
5
2
 
6
3
  from vellum.workflows import BaseWorkflow
7
4
  from vellum.workflows.inputs.base import BaseInputs
@@ -9,95 +6,55 @@ from vellum.workflows.nodes.bases.base import BaseNode
9
6
  from vellum.workflows.state.base import BaseState
10
7
  from vellum.workflows.triggers.base import BaseTrigger
11
8
  from vellum.workflows.triggers.manual import ManualTrigger
12
- from vellum.workflows.types.core import JsonArray, JsonObject
13
9
  from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
14
10
 
15
11
 
16
- class Inputs(BaseInputs):
17
- input: str
18
-
19
-
20
- class SimpleNode(BaseNode):
21
- class Outputs(BaseNode.Outputs):
22
- output = Inputs.input
23
-
24
-
25
- def create_workflow(trigger=None):
26
- """Factory for creating test workflows."""
27
-
28
- class TestWorkflow(BaseWorkflow[Inputs, BaseState]):
29
- graph = trigger >> SimpleNode if trigger else SimpleNode
30
-
31
- class Outputs(BaseWorkflow.Outputs):
32
- output = SimpleNode.Outputs.output
33
-
34
- return TestWorkflow
35
-
36
-
37
- def serialize(workflow_class) -> JsonObject:
38
- """Helper to serialize a workflow."""
39
- return get_workflow_display(workflow_class=workflow_class).serialize()
40
-
41
-
42
12
  def test_manual_trigger_serialization():
43
13
  """Workflow with ManualTrigger serializes with triggers field."""
44
- result = serialize(create_workflow(ManualTrigger))
45
- triggers = cast(JsonArray, result["triggers"])
46
14
 
47
- assert len(triggers) == 1
48
- trigger = cast(JsonObject, triggers[0])
15
+ class SimpleNode(BaseNode):
16
+ pass
49
17
 
50
- assert trigger["type"] == "MANUAL"
51
- assert "id" in trigger
52
- assert "attributes" in trigger
53
- assert trigger["attributes"] == []
54
- assert "definition" not in trigger
18
+ class TestWorkflow(BaseWorkflow[BaseInputs, BaseState]):
19
+ graph = ManualTrigger >> SimpleNode
55
20
 
21
+ result = get_workflow_display(workflow_class=TestWorkflow).serialize()
22
+ assert "triggers" in result
23
+ triggers = result["triggers"]
24
+ assert isinstance(triggers, list)
56
25
 
57
- def test_no_trigger_serialization():
58
- """Workflow without trigger has no triggers field."""
59
- result = serialize(create_workflow())
60
- assert "triggers" not in result
26
+ assert len(triggers) == 1
27
+ assert triggers[0] == {"id": "b09c1902-3cca-4c79-b775-4c32e3e88466", "type": "MANUAL", "attributes": []}
61
28
 
62
29
 
63
30
  def test_manual_trigger_multiple_entrypoints():
64
31
  """ManualTrigger with multiple entrypoints."""
65
32
 
66
33
  class NodeA(BaseNode):
67
- class Outputs(BaseNode.Outputs):
68
- output = Inputs.input
34
+ pass
69
35
 
70
36
  class NodeB(BaseNode):
71
- class Outputs(BaseNode.Outputs):
72
- output = Inputs.input
37
+ pass
73
38
 
74
- class MultiWorkflow(BaseWorkflow[Inputs, BaseState]):
39
+ class MultiWorkflow(BaseWorkflow[BaseInputs, BaseState]):
75
40
  graph = ManualTrigger >> {NodeA, NodeB}
76
41
 
77
- class Outputs(BaseWorkflow.Outputs):
78
- output_a = NodeA.Outputs.output
79
- output_b = NodeB.Outputs.output
80
-
81
- result = serialize(MultiWorkflow)
82
- triggers = cast(JsonArray, result["triggers"])
83
- workflow_data = cast(JsonObject, result["workflow_raw_data"])
84
- nodes = cast(JsonArray, workflow_data["nodes"])
85
-
86
- assert len(triggers) == 1
87
- trigger = cast(JsonObject, triggers[0])
88
- assert trigger["type"] == "MANUAL"
89
- assert len([n for n in nodes if cast(JsonObject, n)["type"] == "GENERIC"]) >= 2
42
+ result = get_workflow_display(workflow_class=MultiWorkflow).serialize()
43
+ workflow_data = result["workflow_raw_data"]
44
+ assert isinstance(workflow_data, dict)
45
+ assert "nodes" in workflow_data
46
+ nodes = workflow_data["nodes"]
47
+ assert isinstance(nodes, list)
90
48
 
49
+ # entrypoint + 2 nodes
50
+ assert len(nodes) == 3
91
51
 
92
- def test_serialized_workflow_structure():
93
- """Verify complete structure of serialized workflow."""
94
- result = serialize(create_workflow(ManualTrigger))
95
- workflow_raw_data = cast(JsonObject, result["workflow_raw_data"])
96
- definition = cast(JsonObject, workflow_raw_data["definition"])
52
+ assert "triggers" in result
53
+ triggers = result["triggers"]
54
+ assert isinstance(triggers, list)
97
55
 
98
- assert result.keys() == {"workflow_raw_data", "input_variables", "state_variables", "output_variables", "triggers"}
99
- assert workflow_raw_data.keys() == {"nodes", "edges", "display_data", "definition", "output_values"}
100
- assert definition["name"] == "TestWorkflow"
56
+ assert len(triggers) == 1
57
+ assert triggers[0] == {"id": "b09c1902-3cca-4c79-b775-4c32e3e88466", "type": "MANUAL", "attributes": []}
101
58
 
102
59
 
103
60
  def test_unknown_trigger_type():
@@ -106,5 +63,11 @@ def test_unknown_trigger_type():
106
63
  class UnknownTrigger(BaseTrigger):
107
64
  pass
108
65
 
66
+ class SimpleNode(BaseNode):
67
+ pass
68
+
69
+ class TestWorkflow(BaseWorkflow[BaseInputs, BaseState]):
70
+ graph = UnknownTrigger >> SimpleNode
71
+
109
72
  with pytest.raises(ValueError, match="Unknown trigger type: UnknownTrigger"):
110
- serialize(create_workflow(UnknownTrigger))
73
+ get_workflow_display(workflow_class=TestWorkflow).serialize()
@@ -0,0 +1,167 @@
1
+ """Tests for serialization of workflows with SlackTrigger."""
2
+
3
+ from vellum.workflows import BaseWorkflow
4
+ from vellum.workflows.inputs.base import BaseInputs
5
+ from vellum.workflows.nodes.bases.base import BaseNode
6
+ from vellum.workflows.state.base import BaseState
7
+ from vellum.workflows.triggers.slack import SlackTrigger
8
+ from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
9
+
10
+
11
+ class Inputs(BaseInputs):
12
+ input: str
13
+
14
+
15
+ class SimpleNode(BaseNode):
16
+ class Outputs(BaseNode.Outputs):
17
+ output = Inputs.input
18
+
19
+
20
+ def test_slack_trigger_serialization() -> None:
21
+ """Workflow with SlackTrigger serializes with triggers field."""
22
+
23
+ class TestWorkflow(BaseWorkflow[Inputs, BaseState]):
24
+ graph = SlackTrigger >> SimpleNode
25
+
26
+ class Outputs(BaseWorkflow.Outputs):
27
+ output = SimpleNode.Outputs.output
28
+
29
+ result = get_workflow_display(workflow_class=TestWorkflow).serialize()
30
+
31
+ # Validate triggers structure
32
+ assert "triggers" in result
33
+ triggers = result["triggers"]
34
+ assert isinstance(triggers, list)
35
+ assert len(triggers) == 1
36
+
37
+ trigger = triggers[0]
38
+ assert isinstance(trigger, dict)
39
+ assert trigger["type"] == "SLACK_MESSAGE"
40
+ assert "id" in trigger
41
+
42
+ # Validate attributes
43
+ assert "attributes" in trigger
44
+ attributes = trigger["attributes"]
45
+ assert isinstance(attributes, list)
46
+ assert len(attributes) == 6
47
+
48
+ attribute_names = set()
49
+ for attribute in attributes:
50
+ assert isinstance(attribute, dict)
51
+ assert "name" in attribute
52
+ assert isinstance(attribute["name"], str)
53
+ attribute_names.add(attribute["name"])
54
+ assert attribute_names == {
55
+ "message",
56
+ "channel",
57
+ "user",
58
+ "timestamp",
59
+ "thread_ts",
60
+ "event_type",
61
+ }
62
+
63
+ for attribute in attributes:
64
+ assert isinstance(attribute, dict)
65
+ assert attribute["value"] is None
66
+ assert isinstance(attribute["id"], str)
67
+ assert attribute["id"]
68
+
69
+
70
+ def test_slack_trigger_multiple_entrypoints() -> None:
71
+ """SlackTrigger with multiple entrypoints."""
72
+
73
+ class NodeA(BaseNode):
74
+ class Outputs(BaseNode.Outputs):
75
+ output = Inputs.input
76
+
77
+ class NodeB(BaseNode):
78
+ class Outputs(BaseNode.Outputs):
79
+ output = Inputs.input
80
+
81
+ class MultiWorkflow(BaseWorkflow[Inputs, BaseState]):
82
+ graph = SlackTrigger >> {NodeA, NodeB}
83
+
84
+ class Outputs(BaseWorkflow.Outputs):
85
+ output_a = NodeA.Outputs.output
86
+ output_b = NodeB.Outputs.output
87
+
88
+ result = get_workflow_display(workflow_class=MultiWorkflow).serialize()
89
+
90
+ # Validate triggers
91
+ assert "triggers" in result
92
+ triggers = result["triggers"]
93
+ assert isinstance(triggers, list)
94
+ assert len(triggers) == 1
95
+
96
+ trigger = triggers[0]
97
+ assert isinstance(trigger, dict)
98
+ assert trigger["type"] == "SLACK_MESSAGE"
99
+
100
+ # Validate attributes
101
+ assert "attributes" in trigger
102
+ attributes = trigger["attributes"]
103
+ assert isinstance(attributes, list)
104
+ attribute_names = set()
105
+ for attribute in attributes:
106
+ assert isinstance(attribute, dict)
107
+ assert "name" in attribute
108
+ assert isinstance(attribute["name"], str)
109
+ attribute_names.add(attribute["name"])
110
+
111
+ assert attribute_names == {
112
+ "message",
113
+ "channel",
114
+ "user",
115
+ "timestamp",
116
+ "thread_ts",
117
+ "event_type",
118
+ }
119
+
120
+ # Validate nodes
121
+ assert "workflow_raw_data" in result
122
+ workflow_data = result["workflow_raw_data"]
123
+ assert isinstance(workflow_data, dict)
124
+ assert "nodes" in workflow_data
125
+ nodes = workflow_data["nodes"]
126
+ assert isinstance(nodes, list)
127
+
128
+ generic_nodes = [node for node in nodes if isinstance(node, dict) and node.get("type") == "GENERIC"]
129
+ assert len(generic_nodes) >= 2
130
+
131
+
132
+ def test_serialized_slack_workflow_structure() -> None:
133
+ """Verify complete structure of serialized workflow with SlackTrigger."""
134
+
135
+ class TestWorkflow(BaseWorkflow[Inputs, BaseState]):
136
+ graph = SlackTrigger >> SimpleNode
137
+
138
+ class Outputs(BaseWorkflow.Outputs):
139
+ output = SimpleNode.Outputs.output
140
+
141
+ result = get_workflow_display(workflow_class=TestWorkflow).serialize()
142
+
143
+ # Validate top-level structure
144
+ assert isinstance(result, dict)
145
+ assert set(result.keys()) == {
146
+ "workflow_raw_data",
147
+ "input_variables",
148
+ "state_variables",
149
+ "output_variables",
150
+ "triggers",
151
+ }
152
+
153
+ # Validate workflow_raw_data structure
154
+ workflow_raw_data = result["workflow_raw_data"]
155
+ assert isinstance(workflow_raw_data, dict)
156
+ assert set(workflow_raw_data.keys()) == {
157
+ "nodes",
158
+ "edges",
159
+ "display_data",
160
+ "definition",
161
+ "output_values",
162
+ }
163
+
164
+ # Validate definition
165
+ definition = workflow_raw_data["definition"]
166
+ assert isinstance(definition, dict)
167
+ assert definition["name"] == "TestWorkflow"
@@ -50,6 +50,7 @@ from vellum.workflows.references.execution_count import ExecutionCountReference
50
50
  from vellum.workflows.references.lazy import LazyReference
51
51
  from vellum.workflows.references.output import OutputReference
52
52
  from vellum.workflows.references.state_value import StateValueReference
53
+ from vellum.workflows.references.trigger import TriggerAttributeReference
53
54
  from vellum.workflows.references.vellum_secret import VellumSecretReference
54
55
  from vellum.workflows.references.workflow_input import WorkflowInputReference
55
56
  from vellum.workflows.types.core import JsonArray, JsonObject
@@ -347,6 +348,17 @@ def serialize_value(executable_id: UUID, display_context: "WorkflowDisplayContex
347
348
  "node_id": str(node_class_display.node_id),
348
349
  }
349
350
 
351
+ if isinstance(value, TriggerAttributeReference):
352
+ # Generate trigger ID using the same hash formula as in base_workflow_display.py
353
+ trigger_class = value.trigger_class
354
+ trigger_id = uuid4_from_hash(trigger_class.__qualname__)
355
+
356
+ return {
357
+ "type": "TRIGGER_ATTRIBUTE",
358
+ "trigger_id": str(trigger_id),
359
+ "attribute_id": str(value.id),
360
+ }
361
+
350
362
  if isinstance(value, list):
351
363
  serialized_items = []
352
364
  for item in value: