vellum-ai 1.2.2__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (85)
  1. vellum/__init__.py +40 -0
  2. vellum/client/core/client_wrapper.py +2 -2
  3. vellum/client/core/pydantic_utilities.py +3 -2
  4. vellum/client/reference.md +16 -0
  5. vellum/client/resources/workflow_executions/client.py +28 -4
  6. vellum/client/resources/workflow_executions/raw_client.py +32 -2
  7. vellum/client/types/__init__.py +40 -0
  8. vellum/client/types/audio_input_request.py +30 -0
  9. vellum/client/types/delimiter_chunker_config.py +20 -0
  10. vellum/client/types/delimiter_chunker_config_request.py +20 -0
  11. vellum/client/types/delimiter_chunking.py +21 -0
  12. vellum/client/types/delimiter_chunking_request.py +21 -0
  13. vellum/client/types/document_index_chunking.py +4 -1
  14. vellum/client/types/document_index_chunking_request.py +2 -1
  15. vellum/client/types/document_input_request.py +30 -0
  16. vellum/client/types/execution_audio_vellum_value.py +31 -0
  17. vellum/client/types/execution_document_vellum_value.py +31 -0
  18. vellum/client/types/execution_image_vellum_value.py +31 -0
  19. vellum/client/types/execution_vellum_value.py +8 -0
  20. vellum/client/types/execution_video_vellum_value.py +31 -0
  21. vellum/client/types/image_input_request.py +30 -0
  22. vellum/client/types/logical_operator.py +1 -0
  23. vellum/client/types/node_input_compiled_audio_value.py +23 -0
  24. vellum/client/types/node_input_compiled_document_value.py +23 -0
  25. vellum/client/types/node_input_compiled_image_value.py +23 -0
  26. vellum/client/types/node_input_compiled_video_value.py +23 -0
  27. vellum/client/types/node_input_variable_compiled_value.py +8 -0
  28. vellum/client/types/prompt_deployment_input_request.py +13 -1
  29. vellum/client/types/prompt_request_audio_input.py +26 -0
  30. vellum/client/types/prompt_request_document_input.py +26 -0
  31. vellum/client/types/prompt_request_image_input.py +26 -0
  32. vellum/client/types/prompt_request_input.py +13 -1
  33. vellum/client/types/prompt_request_video_input.py +26 -0
  34. vellum/client/types/video_input_request.py +30 -0
  35. vellum/types/audio_input_request.py +3 -0
  36. vellum/types/delimiter_chunker_config.py +3 -0
  37. vellum/types/delimiter_chunker_config_request.py +3 -0
  38. vellum/types/delimiter_chunking.py +3 -0
  39. vellum/types/delimiter_chunking_request.py +3 -0
  40. vellum/types/document_input_request.py +3 -0
  41. vellum/types/execution_audio_vellum_value.py +3 -0
  42. vellum/types/execution_document_vellum_value.py +3 -0
  43. vellum/types/execution_image_vellum_value.py +3 -0
  44. vellum/types/execution_video_vellum_value.py +3 -0
  45. vellum/types/image_input_request.py +3 -0
  46. vellum/types/node_input_compiled_audio_value.py +3 -0
  47. vellum/types/node_input_compiled_document_value.py +3 -0
  48. vellum/types/node_input_compiled_image_value.py +3 -0
  49. vellum/types/node_input_compiled_video_value.py +3 -0
  50. vellum/types/prompt_request_audio_input.py +3 -0
  51. vellum/types/prompt_request_document_input.py +3 -0
  52. vellum/types/prompt_request_image_input.py +3 -0
  53. vellum/types/prompt_request_video_input.py +3 -0
  54. vellum/types/video_input_request.py +3 -0
  55. vellum/workflows/context.py +27 -9
  56. vellum/workflows/events/context.py +53 -78
  57. vellum/workflows/events/node.py +5 -5
  58. vellum/workflows/events/relational_threads.py +41 -0
  59. vellum/workflows/events/tests/test_basic_workflow.py +50 -0
  60. vellum/workflows/events/workflow.py +12 -1
  61. vellum/workflows/expressions/contains.py +7 -0
  62. vellum/workflows/expressions/tests/test_contains.py +175 -0
  63. vellum/workflows/graph/graph.py +52 -8
  64. vellum/workflows/graph/tests/test_graph.py +17 -0
  65. vellum/workflows/integrations/mcp_service.py +35 -5
  66. vellum/workflows/integrations/tests/test_mcp_service.py +81 -0
  67. vellum/workflows/nodes/core/error_node/node.py +4 -0
  68. vellum/workflows/nodes/core/map_node/node.py +7 -0
  69. vellum/workflows/nodes/core/map_node/tests/test_node.py +19 -0
  70. vellum/workflows/nodes/displayable/final_output_node/node.py +4 -0
  71. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +1 -1
  72. vellum/workflows/ports/node_ports.py +3 -0
  73. vellum/workflows/ports/port.py +7 -0
  74. vellum/workflows/state/context.py +35 -4
  75. vellum/workflows/utils/uuids.py +15 -0
  76. {vellum_ai-1.2.2.dist-info → vellum_ai-1.2.3.dist-info}/METADATA +1 -1
  77. {vellum_ai-1.2.2.dist-info → vellum_ai-1.2.3.dist-info}/RECORD +85 -39
  78. vellum_ee/workflows/display/nodes/vellum/error_node.py +1 -5
  79. vellum_ee/workflows/display/nodes/vellum/final_output_node.py +1 -5
  80. vellum_ee/workflows/display/utils/events.py +24 -0
  81. vellum_ee/workflows/display/utils/tests/test_events.py +69 -0
  82. vellum_ee/workflows/tests/test_server.py +95 -0
  83. {vellum_ai-1.2.2.dist-info → vellum_ai-1.2.3.dist-info}/LICENSE +0 -0
  84. {vellum_ai-1.2.2.dist-info → vellum_ai-1.2.3.dist-info}/WHEEL +0 -0
  85. {vellum_ai-1.2.2.dist-info → vellum_ai-1.2.3.dist-info}/entry_points.txt +0 -0
@@ -9,40 +9,67 @@ if TYPE_CHECKING:
9
9
  from vellum.workflows.nodes.bases.base import BaseNode
10
10
  from vellum.workflows.ports.port import Port
11
11
 
12
+
13
+ class NoPortsNode:
14
+ """Wrapper for nodes that have no ports defined."""
15
+
16
+ def __init__(self, node_class: Type["BaseNode"]):
17
+ self.node_class = node_class
18
+
19
+ def __repr__(self) -> str:
20
+ return self.node_class.__name__
21
+
22
+ def __rshift__(self, other: "GraphTarget") -> "Graph":
23
+ raise ValueError(
24
+ f"Cannot create edges from {self.node_class.__name__} because it has no ports defined. "
25
+ f"Nodes with empty Ports classes cannot be connected to other nodes."
26
+ )
27
+
28
+
12
29
  GraphTargetOfSets = Union[
13
30
  Set[NodeType],
14
31
  Set["Graph"],
15
32
  Set["Port"],
16
- Set[Union[Type["BaseNode"], "Graph", "Port"]],
33
+ Set[Union[Type["BaseNode"], "Graph", "Port", "NoPortsNode"]],
17
34
  ]
18
35
 
19
36
  GraphTarget = Union[
20
37
  Type["BaseNode"],
21
38
  "Port",
22
39
  "Graph",
40
+ "NoPortsNode",
23
41
  GraphTargetOfSets,
24
42
  ]
25
43
 
26
44
 
27
45
  class Graph:
28
- _entrypoints: Set["Port"]
46
+ _entrypoints: Set[Union["Port", "NoPortsNode"]]
29
47
  _edges: List[Edge]
30
- _terminals: Set["Port"]
31
-
32
- def __init__(self, entrypoints: Set["Port"], edges: List[Edge], terminals: Set["Port"]):
48
+ _terminals: Set[Union["Port", "NoPortsNode"]]
49
+
50
+ def __init__(
51
+ self,
52
+ entrypoints: Set[Union["Port", "NoPortsNode"]],
53
+ edges: List[Edge],
54
+ terminals: Set[Union["Port", "NoPortsNode"]],
55
+ ):
33
56
  self._edges = edges
34
57
  self._entrypoints = entrypoints
35
58
  self._terminals = terminals
36
59
 
37
60
  @staticmethod
38
61
  def from_port(port: "Port") -> "Graph":
39
- ports = {port}
62
+ ports: Set[Union["Port", "NoPortsNode"]] = {port}
40
63
  return Graph(entrypoints=ports, edges=[], terminals=ports)
41
64
 
42
65
  @staticmethod
43
66
  def from_node(node: Type["BaseNode"]) -> "Graph":
44
67
  ports = {port for port in node.Ports}
45
- return Graph(entrypoints=ports, edges=[], terminals=ports)
68
+ if not ports:
69
+ no_ports_node = NoPortsNode(node)
70
+ return Graph(entrypoints={no_ports_node}, edges=[], terminals={no_ports_node})
71
+ ports_set: Set[Union["Port", "NoPortsNode"]] = set(ports)
72
+ return Graph(entrypoints=ports_set, edges=[], terminals=ports_set)
46
73
 
47
74
  @staticmethod
48
75
  def from_set(targets: GraphTargetOfSets) -> "Graph":
@@ -73,10 +100,19 @@ class Graph:
73
100
  if not self._edges and not self._entrypoints:
74
101
  raise ValueError("Graph instance can only create new edges from nodes within existing edges")
75
102
 
103
+ if self._terminals and all(isinstance(terminal, NoPortsNode) for terminal in self._terminals):
104
+ terminal_names = [terminal.node_class.__name__ for terminal in self._terminals]
105
+ raise ValueError(
106
+ f"Cannot create edges from graph because all terminal nodes have no ports defined: "
107
+ f"{', '.join(terminal_names)}. Nodes with empty Ports classes cannot be connected to other nodes."
108
+ )
109
+
76
110
  if isinstance(other, set):
77
111
  new_terminals = set()
78
112
  for elem in other:
79
113
  for final_output_node in self._terminals:
114
+ if isinstance(final_output_node, NoPortsNode):
115
+ continue
80
116
  if isinstance(elem, Graph):
81
117
  midgraph = final_output_node >> set(elem.entrypoints)
82
118
  self._extend_edges(midgraph.edges)
@@ -98,6 +134,8 @@ class Graph:
98
134
 
99
135
  if isinstance(other, Graph):
100
136
  for final_output_node in self._terminals:
137
+ if isinstance(final_output_node, NoPortsNode):
138
+ continue
101
139
  midgraph = final_output_node >> set(other.entrypoints)
102
140
  self._extend_edges(midgraph.edges)
103
141
  self._extend_edges(other.edges)
@@ -106,6 +144,8 @@ class Graph:
106
144
 
107
145
  if hasattr(other, "Ports"):
108
146
  for final_output_node in self._terminals:
147
+ if isinstance(final_output_node, NoPortsNode):
148
+ continue
109
149
  subgraph = final_output_node >> other
110
150
  self._extend_edges(subgraph.edges)
111
151
  self._terminals = {port for port in other.Ports}
@@ -113,6 +153,8 @@ class Graph:
113
153
 
114
154
  # other is a Port
115
155
  for final_output_node in self._terminals:
156
+ if isinstance(final_output_node, NoPortsNode):
157
+ continue
116
158
  subgraph = final_output_node >> other
117
159
  self._extend_edges(subgraph.edges)
118
160
  self._terminals = {other}
@@ -238,8 +280,10 @@ class Graph:
238
280
 
239
281
  return "\n".join(lines)
240
282
 
241
- def _get_port_name(self, port: "Port") -> str:
283
+ def _get_port_name(self, port: Union["Port", "NoPortsNode"]) -> str:
242
284
  """Get a readable name for a port."""
285
+ if isinstance(port, NoPortsNode):
286
+ return f"{port.node_class.__name__} (no ports)"
243
287
  try:
244
288
  if hasattr(port, "node_class") and hasattr(port.node_class, "__name__"):
245
289
  node_name = port.node_class.__name__
@@ -583,3 +583,20 @@ def test_graph__str_single_node():
583
583
  # THEN it shows the single node
584
584
  assert "SingleNode.default" in result
585
585
  assert "Graph:" in result
586
+
587
+
588
+ def test_graph__from_node_with_empty_ports():
589
+ """
590
+ Tests that building a graph from a single node with empty Ports class generates 1 node.
591
+ """
592
+
593
+ # GIVEN a node with an empty Ports class
594
+ class NodeWithEmptyPorts(BaseNode):
595
+ class Ports(BaseNode.Ports):
596
+ pass
597
+
598
+ # WHEN we create a graph from the node
599
+ graph = Graph.from_node(NodeWithEmptyPorts)
600
+
601
+ # THEN the graph should have exactly 1 node
602
+ assert len(list(graph.nodes)) == 1
@@ -73,7 +73,7 @@ class MCPHttpClient:
73
73
  # Prepare headers
74
74
  headers = {
75
75
  "Content-Type": "application/json",
76
- "Accept": "application/json",
76
+ "Accept": "application/json, text/event-stream",
77
77
  }
78
78
 
79
79
  # Include session ID if we have one
@@ -88,11 +88,41 @@ class MCPHttpClient:
88
88
  # Check for session ID in response headers
89
89
  if "Mcp-Session-Id" in response.headers:
90
90
  self.session_id = response.headers["Mcp-Session-Id"]
91
- logger.debug(f"Received session ID: {self.session_id}")
92
91
 
93
- # Handle JSON response
94
- response_data = response.json()
95
- logger.debug(f"Received response: {json.dumps(response_data, indent=2)}")
92
+ # Handle response based on content type
93
+ content_type = response.headers.get("content-type", "").lower()
94
+
95
+ if "text/event-stream" in content_type:
96
+ # Handle SSE response
97
+ response_text = response.text
98
+
99
+ # Parse SSE format to extract JSON data
100
+ lines = response_text.strip().split("\n")
101
+ json_data = None
102
+
103
+ for line in lines:
104
+ if line.startswith("data: "):
105
+ data_content = line[6:] # Remove 'data: ' prefix
106
+ if data_content.strip() and data_content != "[DONE]":
107
+ try:
108
+ json_data = json.loads(data_content)
109
+ break
110
+ except json.JSONDecodeError:
111
+ continue
112
+
113
+ if json_data is None:
114
+ raise Exception("No valid JSON data found in SSE response")
115
+
116
+ response_data = json_data
117
+ else:
118
+ # Handle regular JSON response
119
+ if not response.text.strip():
120
+ raise Exception("Empty response received from server")
121
+
122
+ try:
123
+ response_data = response.json()
124
+ except json.JSONDecodeError as e:
125
+ raise Exception(f"Invalid JSON response: {str(e)}")
96
126
 
97
127
  if "error" in response_data:
98
128
  raise Exception(f"MCP Error: {response_data['error']}")
@@ -0,0 +1,81 @@
1
+ import asyncio
2
+ import json
3
+ from unittest import mock
4
+
5
+ from vellum.workflows.integrations.mcp_service import MCPHttpClient
6
+
7
+
8
+ def test_mcp_http_client_sse_response():
9
+ """Test that SSE responses are correctly parsed to JSON"""
10
+ # GIVEN an SSE response from the server
11
+ sample_sse_response = (
12
+ "event: message\n"
13
+ 'data: {"result":{"protocolVersion":"2025-06-18",'
14
+ '"capabilities":{"tools":{"listChanged":true}},'
15
+ '"serverInfo":{"name":"TestServer","version":"1.0.0"},'
16
+ '"instructions":"Test server for unit tests."},'
17
+ '"jsonrpc":"2.0","id":1}\n\n'
18
+ )
19
+ expected_json = {
20
+ "result": {
21
+ "protocolVersion": "2025-06-18",
22
+ "capabilities": {"tools": {"listChanged": True}},
23
+ "serverInfo": {"name": "TestServer", "version": "1.0.0"},
24
+ "instructions": "Test server for unit tests.",
25
+ },
26
+ "jsonrpc": "2.0",
27
+ "id": 1,
28
+ }
29
+
30
+ with mock.patch("vellum.workflows.integrations.mcp_service.httpx.AsyncClient") as mock_client_class:
31
+ mock_client = mock.AsyncMock()
32
+ mock_client_class.return_value = mock_client
33
+
34
+ mock_response = mock.Mock()
35
+ mock_response.headers = {"content-type": "text/event-stream"}
36
+ mock_response.text = sample_sse_response
37
+ mock_client.post.return_value = mock_response
38
+
39
+ # WHEN we send a request through the MCP client
40
+ async def test_request():
41
+ async with MCPHttpClient("https://test.server.com", {}) as client:
42
+ result = await client._send_request("initialize", {"test": "params"})
43
+ return result
44
+
45
+ result = asyncio.run(test_request())
46
+
47
+ # THEN the SSE response should be parsed correctly to JSON
48
+ assert result == expected_json
49
+
50
+ # AND the request should have been made with correct headers
51
+ mock_client.post.assert_called_once()
52
+ call_args = mock_client.post.call_args
53
+ assert call_args[1]["headers"]["Accept"] == "application/json, text/event-stream"
54
+ assert call_args[1]["headers"]["Content-Type"] == "application/json"
55
+
56
+
57
+ def test_mcp_http_client_json_response():
58
+ """Test that regular JSON responses still work"""
59
+ # GIVEN a regular JSON response from the server
60
+ sample_json_response = {"result": {"test": "data"}, "jsonrpc": "2.0", "id": 1}
61
+
62
+ with mock.patch("vellum.workflows.integrations.mcp_service.httpx.AsyncClient") as mock_client_class:
63
+ mock_client = mock.AsyncMock()
64
+ mock_client_class.return_value = mock_client
65
+
66
+ mock_response = mock.Mock()
67
+ mock_response.headers = {"content-type": "application/json"}
68
+ mock_response.text = json.dumps(sample_json_response)
69
+ mock_response.json.return_value = sample_json_response
70
+ mock_client.post.return_value = mock_response
71
+
72
+ # WHEN we send a request through the MCP client
73
+ async def test_request():
74
+ async with MCPHttpClient("https://test.server.com", {}) as client:
75
+ result = await client._send_request("initialize", {"test": "params"})
76
+ return result
77
+
78
+ result = asyncio.run(test_request())
79
+
80
+ # THEN the JSON response should be returned as expected
81
+ assert result == sample_json_response
@@ -4,6 +4,7 @@ from vellum.client.types.vellum_error import VellumError
4
4
  from vellum.workflows.errors.types import WorkflowError, WorkflowErrorCode, vellum_error_to_workflow_error
5
5
  from vellum.workflows.exceptions import NodeException
6
6
  from vellum.workflows.nodes.bases.base import BaseNode
7
+ from vellum.workflows.ports import NodePorts
7
8
 
8
9
 
9
10
  class ErrorNode(BaseNode):
@@ -15,6 +16,9 @@ class ErrorNode(BaseNode):
15
16
 
16
17
  error: ClassVar[Union[str, WorkflowError, VellumError]]
17
18
 
19
+ class Ports(NodePorts):
20
+ pass
21
+
18
22
  def run(self) -> BaseNode.Outputs:
19
23
  if isinstance(self.error, str):
20
24
  raise NodeException(message=self.error, code=WorkflowErrorCode.USER_DEFINED_ERROR)
@@ -30,6 +30,7 @@ from vellum.workflows.outputs.base import BaseOutput
30
30
  from vellum.workflows.references.output import OutputReference
31
31
  from vellum.workflows.state.context import WorkflowContext
32
32
  from vellum.workflows.types.generics import StateType
33
+ from vellum.workflows.utils.uuids import uuid4_from_hash
33
34
  from vellum.workflows.workflows.event_filters import all_workflow_event_filter
34
35
 
35
36
  if TYPE_CHECKING:
@@ -211,4 +212,10 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
211
212
  annotation = List[parameter_type] # type: ignore[valid-type]
212
213
 
213
214
  previous_annotations = {prev: annotation for prev in outputs_class.__annotations__ if not prev.startswith("_")}
215
+ # Map node output is a list of the same type so we use annotation=List[parameter_type] and not reference
216
+ # class Outputs(BaseOutputs):
217
+ # value: List[str]
214
218
  outputs_class.__annotations__ = {**previous_annotations, reference.name: annotation}
219
+
220
+ output_id = uuid4_from_hash(f"{cls.__id__}|{reference.name}")
221
+ cls.__output_ids__[reference.name] = output_id
@@ -119,6 +119,7 @@ def test_map_node__inner_try():
119
119
  # THEN the workflow should succeed
120
120
  assert outputs[-1].name == "final_output"
121
121
  assert len(outputs[-1].value) == 2
122
+ assert len(SimpleMapNode.__output_ids__) == 1
122
123
 
123
124
 
124
125
  def test_map_node__nested_map_node():
@@ -275,3 +276,21 @@ def test_map_node__shared_state_race_condition():
275
276
  # AND all results should be in correct order
276
277
  expected_result = ["a!", "b!", "c!", "d!", "e!", "f!"]
277
278
  assert final_result == expected_result, f"Failed on run {index}"
279
+
280
+
281
+ def test_map_node__output_ids():
282
+ class TestNode(BaseNode):
283
+ class Outputs(BaseOutputs):
284
+ value: str
285
+
286
+ class SimpleMapNodeWorkflow(BaseWorkflow[MapNode.SubworkflowInputs, BaseState]):
287
+ graph = TestNode
288
+
289
+ class Outputs(BaseWorkflow.Outputs):
290
+ final_output = TestNode.Outputs.value
291
+
292
+ class TestMapNode(MapNode):
293
+ items = [1, 2, 3]
294
+ subworkflow = SimpleMapNodeWorkflow
295
+
296
+ assert len(TestMapNode.__output_ids__) == 1
@@ -4,6 +4,7 @@ from vellum.workflows.constants import undefined
4
4
  from vellum.workflows.nodes.bases import BaseNode
5
5
  from vellum.workflows.nodes.bases.base import BaseNodeMeta
6
6
  from vellum.workflows.nodes.utils import cast_to_output_type
7
+ from vellum.workflows.ports import NodePorts
7
8
  from vellum.workflows.types import MergeBehavior
8
9
  from vellum.workflows.types.generics import StateType
9
10
  from vellum.workflows.types.utils import get_original_base
@@ -47,6 +48,9 @@ class FinalOutputNode(BaseNode[StateType], Generic[StateType, _OutputType], meta
47
48
  class Trigger(BaseNode.Trigger):
48
49
  merge_behavior = MergeBehavior.AWAIT_ANY
49
50
 
51
+ class Ports(NodePorts):
52
+ pass
53
+
50
54
  class Outputs(BaseNode.Outputs):
51
55
  # We use our mypy plugin to override the _OutputType with the actual output type
52
56
  # for downstream references to this output.
@@ -232,7 +232,7 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
232
232
  )
233
233
 
234
234
  resolved_workflow = self._context.resolve_workflow_deployment(
235
- deployment_name=deployment_name, release_tag=self.release_tag
235
+ deployment_name=deployment_name, release_tag=self.release_tag, state=self.state
236
236
  )
237
237
  if resolved_workflow:
238
238
  yield from self._run_resolved_workflow(resolved_workflow)
@@ -40,6 +40,9 @@ class NodePorts(metaclass=_NodePortsMeta):
40
40
 
41
41
  invoked_ports: Set[Port] = set()
42
42
  all_ports = [port for port in self.__class__]
43
+ if not all_ports:
44
+ return set()
45
+
43
46
  enforce_single_invoked_conditional_port = validate_ports(all_ports)
44
47
 
45
48
  for port in all_ports:
@@ -9,6 +9,7 @@ from vellum.workflows.edges.edge import Edge
9
9
  from vellum.workflows.errors.types import WorkflowErrorCode
10
10
  from vellum.workflows.exceptions import NodeException
11
11
  from vellum.workflows.graph import Graph, GraphTarget
12
+ from vellum.workflows.graph.graph import NoPortsNode
12
13
  from vellum.workflows.state.base import BaseState
13
14
  from vellum.workflows.types.core import ConditionType
14
15
 
@@ -66,6 +67,12 @@ class Port:
66
67
  if isinstance(other, Port):
67
68
  return Graph.from_port(self) >> Graph.from_port(other)
68
69
 
70
+ if isinstance(other, NoPortsNode):
71
+ raise ValueError(
72
+ f"Cannot create edge to {other.node_class.__name__} because it has no ports defined. "
73
+ f"Nodes with empty Ports classes cannot be connected to other nodes."
74
+ )
75
+
69
76
  edge = Edge(from_port=self, to_node=other)
70
77
  if edge not in self._edges:
71
78
  self._edges.append(edge)
@@ -4,15 +4,17 @@ from uuid import uuid4
4
4
  from typing import TYPE_CHECKING, Dict, List, Optional, Type
5
5
 
6
6
  from vellum import Vellum
7
- from vellum.workflows.context import ExecutionContext, get_execution_context
7
+ from vellum.workflows.context import ExecutionContext, get_execution_context, set_execution_context
8
8
  from vellum.workflows.events.types import ExternalParentContext
9
9
  from vellum.workflows.nodes.mocks import MockNodeExecution, MockNodeExecutionArg
10
10
  from vellum.workflows.outputs.base import BaseOutputs
11
11
  from vellum.workflows.references.constant import ConstantValueReference
12
+ from vellum.workflows.utils.uuids import generate_workflow_deployment_prefix
12
13
  from vellum.workflows.vellum_client import create_vellum_client
13
14
 
14
15
  if TYPE_CHECKING:
15
16
  from vellum.workflows.events.workflow import WorkflowEvent
17
+ from vellum.workflows.state.base import BaseState
16
18
  from vellum.workflows.workflows.base import BaseWorkflow
17
19
 
18
20
 
@@ -23,11 +25,13 @@ class WorkflowContext:
23
25
  vellum_client: Optional[Vellum] = None,
24
26
  execution_context: Optional[ExecutionContext] = None,
25
27
  generated_files: Optional[dict[str, str]] = None,
28
+ namespace: Optional[str] = None,
26
29
  ):
27
30
  self._vellum_client = vellum_client
28
31
  self._event_queue: Optional[Queue["WorkflowEvent"]] = None
29
32
  self._node_output_mocks_map: Dict[Type[BaseOutputs], List[MockNodeExecution]] = {}
30
33
  self._execution_context = get_execution_context()
34
+ self._namespace = namespace
31
35
 
32
36
  if execution_context is not None:
33
37
  self._execution_context.trace_id = execution_context.trace_id
@@ -36,6 +40,8 @@ class WorkflowContext:
36
40
 
37
41
  if self._execution_context.parent_context is None:
38
42
  self._execution_context.parent_context = ExternalParentContext(span_id=uuid4())
43
+ # Propagate the updated context back to the global execution context
44
+ set_execution_context(self._execution_context)
39
45
 
40
46
  self._generated_files = generated_files
41
47
 
@@ -54,6 +60,10 @@ class WorkflowContext:
54
60
  def generated_files(self) -> Optional[dict[str, str]]:
55
61
  return self._generated_files
56
62
 
63
+ @cached_property
64
+ def namespace(self) -> Optional[str]:
65
+ return self._namespace
66
+
57
67
  @cached_property
58
68
  def node_output_mocks_map(self) -> Dict[Type[BaseOutputs], List[MockNodeExecution]]:
59
69
  return self._node_output_mocks_map
@@ -132,19 +142,40 @@ class WorkflowContext:
132
142
  def _get_all_node_output_mocks(self) -> List[MockNodeExecution]:
133
143
  return [mock for mocks in self._node_output_mocks_map.values() for mock in mocks]
134
144
 
135
- def resolve_workflow_deployment(self, deployment_name: str, release_tag: str) -> Optional["BaseWorkflow"]:
145
+ def resolve_workflow_deployment(
146
+ self, deployment_name: str, release_tag: str, state: "BaseState"
147
+ ) -> Optional["BaseWorkflow"]:
136
148
  """
137
149
  Resolve a workflow deployment by name and release tag.
138
150
 
139
151
  Args:
140
152
  deployment_name: The name of the workflow deployment
141
153
  release_tag: The release tag to resolve
154
+ state: The base state to pass to the workflow
142
155
 
143
156
  Returns:
144
157
  BaseWorkflow instance if found, None otherwise
145
158
  """
146
- return None
159
+ if not self.generated_files or not self.namespace:
160
+ return None
161
+
162
+ expected_prefix = generate_workflow_deployment_prefix(deployment_name, release_tag)
163
+
164
+ workflow_file_key = f"{expected_prefix}/workflow.py"
165
+ if workflow_file_key not in self.generated_files:
166
+ return None
167
+
168
+ try:
169
+ from vellum.workflows.workflows.base import BaseWorkflow
170
+
171
+ WorkflowClass = BaseWorkflow.load_from_module(f"{self.namespace}.{expected_prefix}")
172
+ workflow_instance = WorkflowClass(context=WorkflowContext.create_from(self), parent_state=state)
173
+ return workflow_instance
174
+ except Exception:
175
+ return None
147
176
 
148
177
  @classmethod
149
178
  def create_from(cls, context):
150
- return cls(vellum_client=context.vellum_client, generated_files=context.generated_files)
179
+ return cls(
180
+ vellum_client=context.vellum_client, generated_files=context.generated_files, namespace=context.namespace
181
+ )
@@ -2,6 +2,21 @@ import hashlib
2
2
  from uuid import UUID
3
3
 
4
4
 
5
+ def generate_workflow_deployment_prefix(deployment_name: str, release_tag: str) -> str:
6
+ """
7
+ Generate a workflow deployment prefix from deployment name and release tag.
8
+
9
+ Args:
10
+ deployment_name: The name of the workflow deployment
11
+ release_tag: The release tag to resolve
12
+
13
+ Returns:
14
+ The generated prefix in format vellum_workflow_deployment_{hash}
15
+ """
16
+ expected_hash = str(uuid4_from_hash(f"{deployment_name}|{release_tag}")).replace("-", "_")
17
+ return f"vellum_workflow_deployment_{expected_hash}"
18
+
19
+
5
20
  def uuid4_from_hash(input_str: str) -> UUID:
6
21
  # Create a SHA-256 hash of the input string
7
22
  hash_bytes = hashlib.sha256(input_str.encode()).digest()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vellum-ai
3
- Version: 1.2.2
3
+ Version: 1.2.3
4
4
  Summary:
5
5
  License: MIT
6
6
  Requires-Python: >=3.9,<4.0