vellum-ai 1.2.5__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. vellum/__init__.py +8 -0
  2. vellum/client/README.md +1 -1
  3. vellum/client/core/client_wrapper.py +2 -2
  4. vellum/client/reference.md +0 -9
  5. vellum/client/resources/workflow_sandboxes/client.py +0 -12
  6. vellum/client/resources/workflow_sandboxes/raw_client.py +2 -10
  7. vellum/client/types/__init__.py +8 -0
  8. vellum/client/types/deployment_read.py +5 -5
  9. vellum/client/types/slim_deployment_read.py +5 -5
  10. vellum/client/types/slim_workflow_deployment.py +5 -5
  11. vellum/client/types/workflow_deployment_read.py +5 -5
  12. vellum/client/types/workflow_request_audio_input_request.py +30 -0
  13. vellum/client/types/workflow_request_document_input_request.py +30 -0
  14. vellum/client/types/workflow_request_image_input_request.py +30 -0
  15. vellum/client/types/workflow_request_input_request.py +8 -0
  16. vellum/client/types/workflow_request_video_input_request.py +30 -0
  17. vellum/types/workflow_request_audio_input_request.py +3 -0
  18. vellum/types/workflow_request_document_input_request.py +3 -0
  19. vellum/types/workflow_request_image_input_request.py +3 -0
  20. vellum/types/workflow_request_video_input_request.py +3 -0
  21. vellum/workflows/events/types.py +6 -1
  22. vellum/workflows/events/workflow.py +9 -2
  23. vellum/workflows/integrations/tests/test_mcp_service.py +106 -1
  24. vellum/workflows/nodes/__init__.py +2 -0
  25. vellum/workflows/nodes/displayable/__init__.py +2 -0
  26. vellum/workflows/nodes/displayable/tool_calling_node/utils.py +11 -0
  27. vellum/workflows/nodes/displayable/web_search_node/__init__.py +3 -0
  28. vellum/workflows/nodes/displayable/web_search_node/node.py +133 -0
  29. vellum/workflows/nodes/displayable/web_search_node/tests/__init__.py +0 -0
  30. vellum/workflows/nodes/displayable/web_search_node/tests/test_node.py +319 -0
  31. vellum/workflows/resolvers/base.py +3 -2
  32. vellum/workflows/resolvers/resolver.py +62 -7
  33. vellum/workflows/resolvers/tests/test_resolver.py +79 -7
  34. vellum/workflows/resolvers/types.py +11 -0
  35. vellum/workflows/runner/runner.py +49 -1
  36. vellum/workflows/state/context.py +41 -7
  37. vellum/workflows/utils/zip.py +46 -0
  38. vellum/workflows/workflows/base.py +10 -0
  39. {vellum_ai-1.2.5.dist-info → vellum_ai-1.3.1.dist-info}/METADATA +1 -1
  40. {vellum_ai-1.2.5.dist-info → vellum_ai-1.3.1.dist-info}/RECORD +48 -34
  41. vellum_cli/tests/test_init.py +7 -24
  42. vellum_cli/tests/test_pull.py +27 -52
  43. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +7 -33
  44. vellum_ee/workflows/display/utils/events.py +3 -0
  45. vellum_ee/workflows/tests/test_server.py +115 -0
  46. {vellum_ai-1.2.5.dist-info → vellum_ai-1.3.1.dist-info}/LICENSE +0 -0
  47. {vellum_ai-1.2.5.dist-info → vellum_ai-1.3.1.dist-info}/WHEEL +0 -0
  48. {vellum_ai-1.2.5.dist-info → vellum_ai-1.3.1.dist-info}/entry_points.txt +0 -0
vellum/workflows/nodes/__init__.py
@@ -11,6 +11,7 @@ from vellum.workflows.nodes.displayable import (
11
11
  PromptDeploymentNode,
12
12
  SearchNode,
13
13
  SubworkflowDeploymentNode,
14
+ WebSearchNode,
14
15
  )
15
16
  from vellum.workflows.nodes.displayable.bases import (
16
17
  BaseInlinePromptNode as BaseInlinePromptNode,
@@ -43,4 +44,5 @@ __all__ = [
43
44
  "PromptDeploymentNode",
44
45
  "SearchNode",
45
46
  "SubworkflowDeploymentNode",
47
+ "WebSearchNode",
46
48
  ]
vellum/workflows/nodes/displayable/__init__.py
@@ -14,6 +14,7 @@ from .prompt_deployment_node import PromptDeploymentNode
14
14
  from .search_node import SearchNode
15
15
  from .subworkflow_deployment_node import SubworkflowDeploymentNode
16
16
  from .tool_calling_node import ToolCallingNode
17
+ from .web_search_node import WebSearchNode
17
18
 
18
19
  __all__ = [
19
20
  "APINode",
@@ -31,5 +32,6 @@ __all__ = [
31
32
  "SearchNode",
32
33
  "TemplatingNode",
33
34
  "ToolCallingNode",
35
+ "WebSearchNode",
34
36
  "FinalOutputNode",
35
37
  ]
vellum/workflows/nodes/displayable/tool_calling_node/utils.py
@@ -54,12 +54,23 @@ class FunctionCallNodeMixin:
54
54
  return function_call.value.arguments or {}
55
55
  return {}
56
56
 
57
+ def _extract_function_call_id(self) -> Optional[str]:
58
+ """Extract function call ID from function call output."""
59
+ current_index = getattr(self, "state").current_prompt_output_index
60
+ if self.function_call_output and len(self.function_call_output) > current_index:
61
+ function_call = self.function_call_output[current_index]
62
+ if function_call.type == "FUNCTION_CALL" and function_call.value is not None:
63
+ return function_call.value.id
64
+ return None
65
+
57
66
  def _add_function_result_to_chat_history(self, result: Any, state: ToolCallingState) -> None:
58
67
  """Add function execution result to chat history."""
68
+ function_call_id = self._extract_function_call_id()
59
69
  state.chat_history.append(
60
70
  ChatMessage(
61
71
  role="FUNCTION",
62
72
  content=StringChatMessageContent(value=json.dumps(result, cls=DefaultStateEncoder)),
73
+ source=function_call_id,
63
74
  )
64
75
  )
65
76
  with state.__quiet__():
vellum/workflows/nodes/displayable/web_search_node/__init__.py
@@ -0,0 +1,3 @@
1
+ from .node import WebSearchNode
2
+
3
+ __all__ = ["WebSearchNode"]
vellum/workflows/nodes/displayable/web_search_node/node.py
@@ -0,0 +1,133 @@
1
+ import logging
2
+ from typing import Any, ClassVar, Dict, List, Optional
3
+
4
+ from requests import Request, RequestException, Session
5
+ from requests.exceptions import JSONDecodeError
6
+
7
+ from vellum.workflows.errors.types import WorkflowErrorCode
8
+ from vellum.workflows.exceptions import NodeException
9
+ from vellum.workflows.nodes.bases import BaseNode
10
+ from vellum.workflows.outputs import BaseOutputs
11
+ from vellum.workflows.types.generics import StateType
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class WebSearchNode(BaseNode[StateType]):
    """
    Used to perform web search using SerpAPI.

    query: str - The search query to execute
    api_key: str - SerpAPI authentication key
    num_results: int - Number of search results to return (default: 10)
    location: Optional[str] - Geographic location filter for search
    """

    query: ClassVar[str] = ""
    api_key: ClassVar[Optional[str]] = None
    num_results: ClassVar[int] = 10
    location: ClassVar[Optional[str]] = None

    class Outputs(BaseOutputs):
        """
        The outputs of the WebSearchNode.

        text: str - Concatenated search result snippets with titles
        urls: List[str] - List of URLs from search results
        results: List[Dict[str, Any]] - Raw search results from SerpAPI
        """

        text: str
        urls: List[str]
        results: List[Dict[str, Any]]

    def _validate(self) -> None:
        """Validate node inputs before any network call is made.

        Raises:
            NodeException: ``INVALID_INPUTS`` when the query is missing/blank, the API key is
                missing or blank, or ``num_results`` is not a positive integer.
        """
        if not self.query or not isinstance(self.query, str) or not self.query.strip():
            raise NodeException(
                "Query is required and must be a non-empty string", code=WorkflowErrorCode.INVALID_INPUTS
            )

        # A blank key can never authenticate; reject it here instead of spending a request on a 401.
        if self.api_key is None or (isinstance(self.api_key, str) and not self.api_key.strip()):
            raise NodeException("API key is required", code=WorkflowErrorCode.INVALID_INPUTS)

        # bool is a subclass of int, so ``True`` would otherwise be accepted as num_results=1.
        if (
            isinstance(self.num_results, bool)
            or not isinstance(self.num_results, int)
            or self.num_results <= 0
        ):
            raise NodeException("num_results must be a positive integer", code=WorkflowErrorCode.INVALID_INPUTS)

    def _redact(self, message: str) -> str:
        """Return ``message`` with every occurrence of the API key replaced by ``[REDACTED]``.

        The key is sent as the ``api_key`` query parameter, and request exceptions can echo the
        request URL (e.g. urllib3's "Max retries exceeded with url: ..."). Both the URL-encoded
        form (as written by requests' urlencode) and the raw form are removed.
        """
        from urllib.parse import quote_plus

        api_key = self.api_key
        if not isinstance(api_key, str) or not api_key:
            return message
        for secret in (quote_plus(api_key), api_key):
            message = message.replace(secret, "[REDACTED]")
        return message

    def run(self) -> Outputs:
        """Run the WebSearchNode to perform web search via SerpAPI.

        Returns:
            Outputs: ``text`` renders each organic result as "title: snippet" (or whichever of the
            two is present), joined by blank lines; ``urls`` lists each non-empty ``link``;
            ``results`` is the raw ``organic_results`` list from the response.

        Raises:
            NodeException: ``INVALID_INPUTS`` for invalid inputs or an HTTP 401 response;
                ``PROVIDER_ERROR`` for request failures, HTTP 429 and other >= 400 responses,
                non-JSON or non-object bodies, and bodies carrying an ``error`` field.
        """
        self._validate()

        params: Dict[str, Any] = {
            "q": self.query,
            "api_key": self.api_key,
            "num": self.num_results,
            "engine": "google",
        }
        if self.location:
            params["location"] = self.location

        headers: Dict[str, str] = {}
        user_agent = self._context.vellum_client._client_wrapper.get_headers().get("User-Agent")
        if user_agent:
            # requests refuses None header values when preparing, so only forward a real value.
            headers["User-Agent"] = user_agent

        try:
            prepped = Request(method="GET", url="https://serpapi.com/search", params=params, headers=headers).prepare()
        except Exception as e:
            # The exception text may embed the URL (and thus api_key): redact it and suppress the
            # chained original so tracebacks of the NodeException cannot re-expose the key.
            reason = self._redact(str(e))
            logger.error("Failed to prepare SerpAPI request: %s", reason)
            raise NodeException(
                f"Failed to prepare HTTP request: {reason}", code=WorkflowErrorCode.PROVIDER_ERROR
            ) from None

        try:
            with Session() as session:
                response = session.send(prepped, timeout=30)
        except RequestException as e:
            # Same leak risk as above: connection/timeout errors often echo the full request URL.
            reason = self._redact(str(e))
            logger.error("SerpAPI request failed: %s", reason)
            raise NodeException(f"HTTP request failed: {reason}", code=WorkflowErrorCode.PROVIDER_ERROR) from None

        if response.status_code == 401:
            logger.error("SerpAPI authentication failed")
            raise NodeException("Invalid API key", code=WorkflowErrorCode.INVALID_INPUTS)
        elif response.status_code == 429:
            logger.warning("SerpAPI rate limit exceeded")
            raise NodeException("Rate limit exceeded", code=WorkflowErrorCode.PROVIDER_ERROR)
        elif response.status_code >= 400:
            logger.error("SerpAPI returned error status: %s", response.status_code)
            raise NodeException(f"SerpAPI error: HTTP {response.status_code}", code=WorkflowErrorCode.PROVIDER_ERROR)

        try:
            json_response = response.json()
        except JSONDecodeError as e:
            logger.exception("Failed to parse SerpAPI response as JSON")
            raise NodeException(
                f"Invalid JSON response from SerpAPI: {e}", code=WorkflowErrorCode.PROVIDER_ERROR
            ) from e

        # A valid but non-object payload (e.g. a list) would otherwise fail below with AttributeError.
        if not isinstance(json_response, dict):
            logger.error("SerpAPI returned a JSON %s instead of an object", type(json_response).__name__)
            raise NodeException(
                "Invalid JSON response from SerpAPI: expected a JSON object", code=WorkflowErrorCode.PROVIDER_ERROR
            )

        if "error" in json_response:
            error_msg = self._redact(str(json_response["error"]))
            logger.error("SerpAPI returned error: %s", error_msg)
            raise NodeException(f"SerpAPI error: {error_msg}", code=WorkflowErrorCode.PROVIDER_ERROR)

        # ``or []`` also covers an explicit null for organic_results.
        organic_results = json_response.get("organic_results") or []

        text_results: List[str] = []
        urls: List[str] = []

        for result in organic_results:
            if not isinstance(result, dict):
                # Skip malformed entries rather than failing the whole search.
                continue

            title = result.get("title", "")
            snippet = result.get("snippet", "")
            link = result.get("link", "")

            if title and snippet:
                text_results.append(f"{title}: {snippet}")
            elif title:
                text_results.append(title)
            elif snippet:
                text_results.append(snippet)

            if link:
                urls.append(link)

        return self.Outputs(text="\n\n".join(text_results), urls=urls, results=organic_results)
vellum/workflows/nodes/displayable/web_search_node/tests/test_node.py
@@ -0,0 +1,319 @@
1
+ import pytest
2
+ from unittest.mock import MagicMock
3
+
4
+ import requests
5
+
6
+ from vellum.workflows.errors.types import WorkflowErrorCode
7
+ from vellum.workflows.exceptions import NodeException
8
+ from vellum.workflows.inputs import BaseInputs
9
+ from vellum.workflows.state import BaseState
10
+ from vellum.workflows.state.base import StateMeta
11
+
12
+ from ..node import WebSearchNode
13
+
14
+
15
@pytest.fixture
def base_node_setup(vellum_client):
    """Build a WebSearchNode whose query/api_key/num_results come from workflow inputs."""

    class Inputs(BaseInputs):
        query: str
        api_key: str
        num_results: int

    class State(BaseState):
        pass

    class TestableWebSearchNode(WebSearchNode):
        query = Inputs.query
        api_key = Inputs.api_key
        num_results = Inputs.num_results

    workflow_inputs = Inputs(query="test query", api_key="test_api_key", num_results=3)
    node_state = State(meta=StateMeta(workflow_inputs=workflow_inputs))

    node_context = MagicMock()
    node_context.vellum_client = vellum_client

    return TestableWebSearchNode(state=node_state, context=node_context)
37
+
38
+
39
def test_successful_search_with_results(base_node_setup, requests_mock):
    """A typical SerpAPI payload yields formatted text, ordered URLs, raw results and correct params."""
    # GIVEN SerpAPI answers with three well-formed organic results
    organic_results = [
        {
            "title": "First Result",
            "snippet": "This is the first search result snippet",
            "link": "https://example1.com",
            "position": 1,
        },
        {
            "title": "Second Result",
            "snippet": "This is the second search result snippet",
            "link": "https://example2.com",
            "position": 2,
        },
        {
            "title": "Third Result",
            "snippet": "This is the third search result snippet",
            "link": "https://example3.com",
            "position": 3,
        },
    ]
    requests_mock.get("https://serpapi.com/search", json={"organic_results": organic_results})

    # WHEN the node runs
    result = base_node_setup.run()

    # THEN each result is rendered as "title: snippet", separated by blank lines
    assert result.text == (
        "First Result: This is the first search result snippet\n\n"
        "Second Result: This is the second search result snippet\n\n"
        "Third Result: This is the third search result snippet"
    )

    # AND the links are collected in order
    assert result.urls == ["https://example1.com", "https://example2.com", "https://example3.com"]

    # AND the raw organic results are passed through untouched
    assert result.results == organic_results

    # AND the outgoing request carried the expected query parameters
    sent_params = requests_mock.last_request.qs
    assert sent_params == {
        "q": ["test query"],
        "api_key": ["test_api_key"],
        "num": ["3"],
        "engine": ["google"],
    }
+ }
91
+
92
+
93
def test_search_with_location_parameter(base_node_setup, requests_mock):
    """Test that the optional location filter is forwarded to SerpAPI as the `location` query param."""
    # GIVEN a location parameter is set
    base_node_setup.location = "New York, NY"

    requests_mock.get("https://serpapi.com/search", json={"organic_results": []})

    # WHEN we run the node
    base_node_setup.run()

    # THEN the location is sent exactly once. requests_mock lowercases the URL it parses into `qs`
    # unless the mocker is created with case_sensitive=True, so compare case-insensitively.
    sent_params = requests_mock.last_request.qs
    assert "location" in sent_params
    assert len(sent_params["location"]) == 1
    assert sent_params["location"][0].lower() == "new york, ny"
106
+
107
+
108
def test_authentication_error_401(base_node_setup, requests_mock):
    """An HTTP 401 from SerpAPI surfaces as INVALID_INPUTS with an "Invalid API key" message."""
    # GIVEN SerpAPI rejects the credentials
    requests_mock.get("https://serpapi.com/search", status_code=401)

    # WHEN the node runs THEN it raises with the invalid-key message
    with pytest.raises(NodeException, match="Invalid API key") as raised:
        base_node_setup.run()

    # AND the failure is classified as bad input
    assert raised.value.code == WorkflowErrorCode.INVALID_INPUTS
120
+
121
+
122
def test_rate_limit_error_429(base_node_setup, requests_mock):
    """An HTTP 429 from SerpAPI surfaces as PROVIDER_ERROR with a rate-limit message."""
    # GIVEN SerpAPI is throttling requests
    requests_mock.get("https://serpapi.com/search", status_code=429)

    # WHEN the node runs THEN it raises with the rate-limit message
    with pytest.raises(NodeException, match="Rate limit exceeded") as raised:
        base_node_setup.run()

    # AND the failure is attributed to the provider
    assert raised.value.code == WorkflowErrorCode.PROVIDER_ERROR
134
+
135
+
136
def test_server_error_500(base_node_setup, requests_mock):
    """Server-side failures surface as PROVIDER_ERROR naming the HTTP status."""
    # GIVEN SerpAPI fails with an internal error
    requests_mock.get("https://serpapi.com/search", status_code=500)

    # WHEN the node runs THEN the status code appears in the message
    with pytest.raises(NodeException, match="SerpAPI error: HTTP 500") as raised:
        base_node_setup.run()

    # AND the failure is attributed to the provider
    assert raised.value.code == WorkflowErrorCode.PROVIDER_ERROR
148
+
149
+
150
def test_invalid_json_response(base_node_setup, requests_mock):
    """A body that is not JSON surfaces as PROVIDER_ERROR."""
    # GIVEN SerpAPI responds with plain text
    requests_mock.get("https://serpapi.com/search", text="Not JSON")

    # WHEN the node runs THEN it reports the unparsable body
    with pytest.raises(NodeException, match="Invalid JSON response") as raised:
        base_node_setup.run()

    # AND the failure is attributed to the provider
    assert raised.value.code == WorkflowErrorCode.PROVIDER_ERROR
162
+
163
+
164
def test_serpapi_error_in_response(base_node_setup, requests_mock):
    """An `error` field in an otherwise successful response surfaces as PROVIDER_ERROR."""
    # GIVEN SerpAPI embeds an error message in the JSON body
    requests_mock.get("https://serpapi.com/search", json={"error": "Invalid search parameters"})

    # WHEN the node runs THEN SerpAPI's message is propagated
    with pytest.raises(NodeException, match="Invalid search parameters") as raised:
        base_node_setup.run()

    # AND the failure is attributed to the provider
    assert raised.value.code == WorkflowErrorCode.PROVIDER_ERROR
176
+
177
+
178
def test_empty_query_validation(vellum_client):
    """An empty query is rejected by validation with INVALID_INPUTS."""

    # GIVEN a node whose query is the empty string
    class EmptyQueryNode(WebSearchNode):
        query = ""
        api_key = "test_key"
        num_results = 10

    node_context = MagicMock()
    node_context.vellum_client = vellum_client
    empty_state = BaseState(meta=StateMeta(workflow_inputs=BaseInputs()))
    node = EmptyQueryNode(state=empty_state, context=node_context)

    # WHEN we run the node THEN validation fails on the query
    with pytest.raises(NodeException, match="Query is required") as raised:
        node.run()

    assert raised.value.code == WorkflowErrorCode.INVALID_INPUTS
198
+
199
+
200
def test_missing_api_key_validation(vellum_client):
    """A missing API key is rejected by validation with INVALID_INPUTS."""

    # GIVEN a node with no API key configured
    class NoApiKeyNode(WebSearchNode):
        query = "test query"
        api_key = None
        num_results = 10

    node_context = MagicMock()
    node_context.vellum_client = vellum_client
    empty_state = BaseState(meta=StateMeta(workflow_inputs=BaseInputs()))
    node = NoApiKeyNode(state=empty_state, context=node_context)

    # WHEN we run the node THEN validation fails on the API key
    with pytest.raises(NodeException, match="API key is required") as raised:
        node.run()

    assert raised.value.code == WorkflowErrorCode.INVALID_INPUTS
220
+
221
+
222
def test_invalid_num_results_validation(vellum_client):
    """A non-positive num_results is rejected by validation with INVALID_INPUTS."""

    # GIVEN a node asking for a negative number of results
    class NegativeCountNode(WebSearchNode):
        query = "test query"
        api_key = "test_key"
        num_results = -1

    node_context = MagicMock()
    node_context.vellum_client = vellum_client
    empty_state = BaseState(meta=StateMeta(workflow_inputs=BaseInputs()))
    node = NegativeCountNode(state=empty_state, context=node_context)

    # WHEN we run the node THEN validation fails on num_results
    with pytest.raises(NodeException, match="num_results must be a positive integer") as raised:
        node.run()

    assert raised.value.code == WorkflowErrorCode.INVALID_INPUTS
242
+
243
+
244
def test_empty_organic_results(base_node_setup, requests_mock):
    """With no organic results, every output is empty."""
    # GIVEN SerpAPI finds nothing
    requests_mock.get("https://serpapi.com/search", json={"organic_results": []})

    # WHEN the node runs
    result = base_node_setup.run()

    # THEN text, urls and results are all empty
    assert (result.text, result.urls, result.results) == ("", [], [])
256
+
257
+
258
def test_missing_fields_in_results(base_node_setup, requests_mock):
    """Test handling of missing title, snippet, or link fields."""
    # GIVEN SerpAPI returns results with missing fields
    mock_response = {
        "organic_results": [
            {
                "title": "Only Title",
                # Missing snippet and link
            },
            {
                "snippet": "Only snippet, no title or link",
                # Missing title and link
            },
            {
                "title": "Title with link",
                "link": "https://example.com",
                # Missing snippet
            },
            {
                # All fields missing - contributes no text and no URL
                "position": 4
            },
        ]
    }

    requests_mock.get("https://serpapi.com/search", json=mock_response)

    # WHEN we run the node
    outputs = base_node_setup.run()

    # THEN text uses whichever of title/snippet is present and skips entries with neither
    assert outputs.text == "Only Title\n\nOnly snippet, no title or link\n\nTitle with link"

    # AND URLs should only include non-empty links
    assert outputs.urls == ["https://example.com"]

    # AND the raw results are preserved verbatim, including the entry that contributed nothing
    assert outputs.results == mock_response["organic_results"]
294
+
295
+
296
def test_request_timeout_handling(base_node_setup, requests_mock):
    """A transport-level timeout surfaces as PROVIDER_ERROR."""
    # GIVEN the HTTP call times out
    timeout = requests.exceptions.Timeout("Connection timed out")
    requests_mock.get("https://serpapi.com/search", exc=timeout)

    # WHEN the node runs THEN it reports a failed HTTP request
    with pytest.raises(NodeException, match="HTTP request failed") as raised:
        base_node_setup.run()

    # AND the failure is attributed to the provider
    assert raised.value.code == WorkflowErrorCode.PROVIDER_ERROR
308
+
309
+
310
def test_user_agent_header_included(base_node_setup, requests_mock):
    """The node forwards the vellum client's User-Agent header to SerpAPI."""
    # GIVEN SerpAPI responds successfully
    requests_mock.get("https://serpapi.com/search", json={"organic_results": []})

    # WHEN the node runs
    base_node_setup.run()

    # THEN the outgoing request carries the SDK's User-Agent
    sent_headers = requests_mock.last_request.headers
    assert sent_headers["User-Agent"] == "vellum-python-sdk/1.0.0"
vellum/workflows/resolvers/base.py
@@ -1,8 +1,9 @@
1
1
  from abc import ABC, abstractmethod
2
2
  from uuid import UUID
3
- from typing import TYPE_CHECKING, Iterator, Optional, Type
3
+ from typing import TYPE_CHECKING, Iterator, Optional, Type, Union
4
4
 
5
5
  from vellum.workflows.events.workflow import WorkflowEvent
6
+ from vellum.workflows.resolvers.types import LoadStateResult
6
7
  from vellum.workflows.state.base import BaseState
7
8
 
8
9
  if TYPE_CHECKING:
@@ -28,5 +29,5 @@ class BaseWorkflowResolver(ABC):
28
29
  pass
29
30
 
30
31
  @abstractmethod
31
- def load_state(self, previous_execution_id: Optional[UUID] = None) -> Optional[BaseState]:
32
+ def load_state(self, previous_execution_id: Optional[Union[UUID, str]] = None) -> Optional[LoadStateResult]:
32
33
  pass
vellum/workflows/resolvers/resolver.py
@@ -1,10 +1,13 @@
1
1
  import logging
2
2
  from uuid import UUID
3
- from typing import Iterator, Optional
3
+ from typing import Iterator, List, Optional, Tuple, Union
4
4
 
5
+ from vellum.client.types.vellum_span import VellumSpan
6
+ from vellum.client.types.workflow_execution_initiated_event import WorkflowExecutionInitiatedEvent
5
7
  from vellum.workflows.events.workflow import WorkflowEvent
6
8
  from vellum.workflows.resolvers.base import BaseWorkflowResolver
7
- from vellum.workflows.state.base import BaseState, StateMeta
9
+ from vellum.workflows.resolvers.types import LoadStateResult
10
+ from vellum.workflows.state.base import BaseState
8
11
 
9
12
  logger = logging.getLogger(__name__)
10
13
 
@@ -16,7 +19,42 @@ class VellumResolver(BaseWorkflowResolver):
16
19
  def get_state_snapshot_history(self) -> Iterator[BaseState]:
17
20
  return iter([])
18
21
 
19
- def load_state(self, previous_execution_id: Optional[UUID] = None) -> Optional[BaseState]:
22
# Locate the "workflow.execution" span whose span_id equals execution_id and read trace/span ids
# from its WorkflowExecutionInitiatedEvent.
# Returns (previous_trace_id, root_trace_id, previous_span_id, root_span_id). All four stay None when
# no matching span with an initiated event is found. Only the root pair stays None when the event has
# links but none of type "ROOT_SPAN". load_state returns None if any of the four is None.
# NOTE(review): the nesting of `break` (and the pairing of `else`, which its comment ties to
# `if links:`) is not visible in this flattened rendering. Confirm whether the scan stops at the first
# matching span or only once an initiated event has been found.
+ def _find_previous_and_root_span(
23
+ self, execution_id: str, spans: List[VellumSpan]
24
+ ) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
25
+ previous_trace_id: Optional[str] = None
26
+ root_trace_id: Optional[str] = None
27
+ previous_span_id: Optional[str] = None
28
+ root_span_id: Optional[str] = None
29
+
30
+ for span in spans:
31
+ # Look for workflow execution spans with matching ID first
32
+ if span.name == "workflow.execution" and span.span_id == execution_id:
33
+ # Find the WorkflowExecutionInitiatedEvent in the span's events
34
+ initiated_event = next(
35
+ (event for event in span.events if isinstance(event, WorkflowExecutionInitiatedEvent)), None
36
+ )
37
+ if initiated_event:
38
+ previous_trace_id = initiated_event.trace_id
39
+ previous_span_id = initiated_event.span_id
40
+ links = initiated_event.links
41
# A ROOT_SPAN link points at the first execution in the chain; its ids become the root pair.
+ if links:
42
+ root_span = next((link for link in links if link.type == "ROOT_SPAN"), None)
43
+ if root_span:
44
+ root_trace_id = root_span.trace_id
45
+ root_span_id = root_span.span_context.span_id
46
+ else:
47
+ # no links means this is the first execution
48
+ root_trace_id = initiated_event.trace_id
49
+ root_span_id = initiated_event.span_id
50
+ break
51
+
52
+ return previous_trace_id, root_trace_id, previous_span_id, root_span_id
53
+
54
+ def load_state(self, previous_execution_id: Optional[Union[UUID, str]] = None) -> Optional[LoadStateResult]:
55
+ if isinstance(previous_execution_id, UUID):
56
+ previous_execution_id = str(previous_execution_id)
57
+
20
58
  if previous_execution_id is None:
21
59
  return None
22
60
 
@@ -26,17 +64,34 @@ class VellumResolver(BaseWorkflowResolver):
26
64
 
27
65
  client = self._context.vellum_client
28
66
  response = client.workflow_executions.retrieve_workflow_execution_detail(
29
- execution_id=str(previous_execution_id),
67
+ execution_id=previous_execution_id,
30
68
  )
31
69
 
32
70
  if response.state is None:
33
71
  return None
34
72
 
35
- meta = StateMeta.model_validate(response.state.pop("meta"))
73
+ previous_trace_id, root_trace_id, previous_span_id, root_span_id = self._find_previous_and_root_span(
74
+ previous_execution_id, response.spans
75
+ )
76
+
77
+ if previous_trace_id is None or root_trace_id is None or previous_span_id is None or root_span_id is None:
78
+ logger.warning("Could not find required execution events for state loading")
79
+ return None
80
+
81
+ if "meta" in response.state:
82
+ response.state.pop("meta")
36
83
 
37
84
  if self._workflow_class:
38
85
  state_class = self._workflow_class.get_state_class()
39
- return state_class(**response.state, meta=meta)
86
+ state = state_class(**response.state)
40
87
  else:
41
88
  logger.warning("No workflow class registered, falling back to BaseState")
42
- return BaseState(**response.state, meta=meta)
89
+ state = BaseState(**response.state)
90
+
91
+ return LoadStateResult(
92
+ state=state,
93
+ previous_trace_id=previous_trace_id,
94
+ previous_span_id=previous_span_id,
95
+ root_trace_id=root_trace_id,
96
+ root_span_id=root_span_id,
97
+ )