vellum-ai 1.2.4__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/__init__.py +56 -0
- vellum/client/README.md +1 -1
- vellum/client/core/client_wrapper.py +2 -2
- vellum/client/reference.md +0 -9
- vellum/client/resources/workflow_sandboxes/client.py +0 -12
- vellum/client/resources/workflow_sandboxes/raw_client.py +2 -10
- vellum/client/resources/workflows/client.py +20 -0
- vellum/client/resources/workflows/raw_client.py +20 -0
- vellum/client/types/__init__.py +56 -0
- vellum/client/types/audio_input.py +30 -0
- vellum/client/types/code_executor_input.py +8 -0
- vellum/client/types/deployment_read.py +5 -5
- vellum/client/types/document_input.py +30 -0
- vellum/client/types/image_input.py +30 -0
- vellum/client/types/named_scenario_input_audio_variable_value_request.py +22 -0
- vellum/client/types/named_scenario_input_document_variable_value_request.py +22 -0
- vellum/client/types/named_scenario_input_image_variable_value_request.py +22 -0
- vellum/client/types/named_scenario_input_request.py +8 -0
- vellum/client/types/named_scenario_input_video_variable_value_request.py +22 -0
- vellum/client/types/named_test_case_audio_variable_value.py +26 -0
- vellum/client/types/named_test_case_audio_variable_value_request.py +26 -0
- vellum/client/types/named_test_case_document_variable_value.py +22 -0
- vellum/client/types/named_test_case_document_variable_value_request.py +22 -0
- vellum/client/types/named_test_case_image_variable_value.py +22 -0
- vellum/client/types/named_test_case_image_variable_value_request.py +22 -0
- vellum/client/types/named_test_case_variable_value.py +8 -0
- vellum/client/types/named_test_case_variable_value_request.py +8 -0
- vellum/client/types/named_test_case_video_variable_value.py +22 -0
- vellum/client/types/named_test_case_video_variable_value_request.py +22 -0
- vellum/client/types/node_execution_span_attributes.py +1 -0
- vellum/client/types/scenario_input.py +11 -1
- vellum/client/types/scenario_input_audio_variable_value.py +22 -0
- vellum/client/types/scenario_input_document_variable_value.py +22 -0
- vellum/client/types/scenario_input_image_variable_value.py +22 -0
- vellum/client/types/scenario_input_video_variable_value.py +22 -0
- vellum/client/types/slim_deployment_read.py +5 -5
- vellum/client/types/slim_workflow_deployment.py +5 -5
- vellum/client/types/span_link.py +1 -1
- vellum/client/types/span_link_type_enum.py +1 -1
- vellum/client/types/test_case_audio_variable_value.py +27 -0
- vellum/client/types/test_case_document_variable_value.py +27 -0
- vellum/client/types/test_case_image_variable_value.py +27 -0
- vellum/client/types/test_case_variable_value.py +8 -0
- vellum/client/types/test_case_video_variable_value.py +27 -0
- vellum/client/types/video_input.py +30 -0
- vellum/client/types/workflow_deployment_read.py +5 -5
- vellum/client/types/workflow_push_deployment_config_request.py +1 -0
- vellum/client/types/workflow_request_audio_input_request.py +30 -0
- vellum/client/types/workflow_request_document_input_request.py +30 -0
- vellum/client/types/workflow_request_image_input_request.py +30 -0
- vellum/client/types/workflow_request_input_request.py +8 -0
- vellum/client/types/workflow_request_video_input_request.py +30 -0
- vellum/types/audio_input.py +3 -0
- vellum/types/document_input.py +3 -0
- vellum/types/image_input.py +3 -0
- vellum/types/named_scenario_input_audio_variable_value_request.py +3 -0
- vellum/types/named_scenario_input_document_variable_value_request.py +3 -0
- vellum/types/named_scenario_input_image_variable_value_request.py +3 -0
- vellum/types/named_scenario_input_video_variable_value_request.py +3 -0
- vellum/types/named_test_case_audio_variable_value.py +3 -0
- vellum/types/named_test_case_audio_variable_value_request.py +3 -0
- vellum/types/named_test_case_document_variable_value.py +3 -0
- vellum/types/named_test_case_document_variable_value_request.py +3 -0
- vellum/types/named_test_case_image_variable_value.py +3 -0
- vellum/types/named_test_case_image_variable_value_request.py +3 -0
- vellum/types/named_test_case_video_variable_value.py +3 -0
- vellum/types/named_test_case_video_variable_value_request.py +3 -0
- vellum/types/scenario_input_audio_variable_value.py +3 -0
- vellum/types/scenario_input_document_variable_value.py +3 -0
- vellum/types/scenario_input_image_variable_value.py +3 -0
- vellum/types/scenario_input_video_variable_value.py +3 -0
- vellum/types/test_case_audio_variable_value.py +3 -0
- vellum/types/test_case_document_variable_value.py +3 -0
- vellum/types/test_case_image_variable_value.py +3 -0
- vellum/types/test_case_video_variable_value.py +3 -0
- vellum/types/video_input.py +3 -0
- vellum/types/workflow_request_audio_input_request.py +3 -0
- vellum/types/workflow_request_document_input_request.py +3 -0
- vellum/types/workflow_request_image_input_request.py +3 -0
- vellum/types/workflow_request_video_input_request.py +3 -0
- vellum/workflows/events/types.py +6 -1
- vellum/workflows/integrations/tests/test_mcp_service.py +106 -1
- vellum/workflows/nodes/__init__.py +2 -0
- vellum/workflows/nodes/displayable/__init__.py +2 -0
- vellum/workflows/nodes/displayable/web_search_node/__init__.py +3 -0
- vellum/workflows/nodes/displayable/web_search_node/node.py +133 -0
- vellum/workflows/resolvers/base.py +19 -1
- vellum/workflows/resolvers/resolver.py +97 -0
- vellum/workflows/resolvers/tests/test_resolver.py +131 -0
- vellum/workflows/resolvers/types.py +11 -0
- vellum/workflows/runner/runner.py +49 -1
- vellum/workflows/state/context.py +41 -7
- vellum/workflows/utils/zip.py +46 -0
- vellum/workflows/workflows/base.py +13 -0
- {vellum_ai-1.2.4.dist-info → vellum_ai-1.3.0.dist-info}/METADATA +1 -1
- {vellum_ai-1.2.4.dist-info → vellum_ai-1.3.0.dist-info}/RECORD +105 -43
- vellum_cli/tests/test_init.py +7 -24
- vellum_cli/tests/test_pull.py +27 -52
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +7 -33
- vellum_ee/workflows/display/utils/events.py +19 -1
- vellum_ee/workflows/display/utils/tests/test_events.py +42 -0
- vellum_ee/workflows/tests/test_server.py +115 -0
- {vellum_ai-1.2.4.dist-info → vellum_ai-1.3.0.dist-info}/LICENSE +0 -0
- {vellum_ai-1.2.4.dist-info → vellum_ai-1.3.0.dist-info}/WHEEL +0 -0
- {vellum_ai-1.2.4.dist-info → vellum_ai-1.3.0.dist-info}/entry_points.txt +0 -0
vellum/workflows/events/types.py
CHANGED
@@ -7,7 +7,6 @@ from pydantic import Field, GetCoreSchemaHandler, Tag, ValidationInfo
|
|
7
7
|
from pydantic_core import CoreSchema, core_schema
|
8
8
|
|
9
9
|
from vellum.client.core.pydantic_utilities import UniversalBaseModel
|
10
|
-
from vellum.client.types.span_link import SpanLink
|
11
10
|
from vellum.workflows.state.encoder import DefaultStateEncoder
|
12
11
|
from vellum.workflows.types.definition import VellumCodeResourceDefinition
|
13
12
|
from vellum.workflows.types.utils import datetime_now
|
@@ -86,6 +85,12 @@ class ExternalParentContext(BaseParentContext):
|
|
86
85
|
type: Literal["EXTERNAL"] = "EXTERNAL"
|
87
86
|
|
88
87
|
|
88
|
+
class SpanLink(UniversalBaseModel):
    """A typed link from the current span to another span, identified by trace id plus span context."""

    # Id of the trace that the linked span belongs to.
    trace_id: str
    # How the linked span relates to the span carrying this link.
    type: Literal["TRIGGERED_BY", "PREVIOUS_SPAN", "ROOT_SPAN"]
    # Context of the linked span. Quoted as a forward reference
    # (presumably ParentContext is declared later in this module — confirm).
    span_context: "ParentContext"
|
92
|
+
|
93
|
+
|
89
94
|
def _cast_parent_context_discriminator(v: Any) -> Any:
|
90
95
|
if v in PARENT_CONTEXT_TYPES:
|
91
96
|
return v
|
@@ -1,10 +1,11 @@
|
|
1
|
+
import pytest
|
1
2
|
import asyncio
|
2
3
|
import json
|
3
4
|
from unittest import mock
|
4
5
|
|
5
6
|
from vellum.workflows.constants import AuthorizationType
|
6
7
|
from vellum.workflows.integrations.mcp_service import MCPHttpClient, MCPService
|
7
|
-
from vellum.workflows.types.definition import MCPServer
|
8
|
+
from vellum.workflows.types.definition import MCPServer, MCPToolDefinition
|
8
9
|
|
9
10
|
|
10
11
|
def test_mcp_http_client_sse_response():
|
@@ -118,3 +119,107 @@ def test_mcp_service_api_key_auth():
|
|
118
119
|
|
119
120
|
# THEN the custom API key header should be set correctly
|
120
121
|
assert headers == {"X-API-Key": "api-key-123"}
|
122
|
+
|
123
|
+
|
124
|
+
@pytest.mark.asyncio
async def test_mcp_http_client_empty_response():
    """MCPHttpClient.initialize raises when the server answers with an empty JSON body."""
    # GIVEN an HTTP response that claims JSON but carries no body
    empty_response = mock.Mock(headers={"content-type": "application/json"}, text="")

    # AND an httpx client whose POST always yields that response
    fake_http_client = mock.AsyncMock()
    fake_http_client.post.return_value = empty_response

    with mock.patch(
        "vellum.workflows.integrations.mcp_service.httpx.AsyncClient",
        return_value=fake_http_client,
    ):
        # WHEN the client is initialized against the empty response
        async with MCPHttpClient("https://test.server.com", {}) as client:
            # THEN the empty body is surfaced as an error
            with pytest.raises(Exception, match="Empty response received from server"):
                await client.initialize()
|
143
|
+
|
144
|
+
|
145
|
+
@pytest.mark.asyncio
async def test_mcp_http_client_invalid_sse_json():
    """MCPHttpClient.initialize raises when an SSE stream's data line is not valid JSON."""
    # GIVEN an event-stream body whose only data line is malformed JSON
    malformed_stream = "event: message\n" "data: {invalid json}\n" "\n"
    sse_response = mock.Mock(headers={"content-type": "text/event-stream"}, text=malformed_stream)

    # AND an httpx client whose POST always yields that response
    fake_http_client = mock.AsyncMock()
    fake_http_client.post.return_value = sse_response

    with mock.patch(
        "vellum.workflows.integrations.mcp_service.httpx.AsyncClient",
        return_value=fake_http_client,
    ):
        # WHEN the client is initialized against the malformed stream
        async with MCPHttpClient("https://test.server.com", {}) as client:
            # THEN it reports that no usable JSON payload was found
            with pytest.raises(Exception, match="No valid JSON data found in SSE response"):
                await client.initialize()
|
168
|
+
|
169
|
+
|
170
|
+
def test_mcp_service_hydrate_tool_definitions():
    """hydrate_tool_definitions converts raw tool dicts into MCPToolDefinition objects bound to the server.

    The async tool listing is stubbed by patching ``asyncio.run`` as seen from the MCP service module.
    """
    # GIVEN an MCP server that uses bearer-token authorization
    server = MCPServer(
        name="test-server",
        url="https://test.mcp.server.com/mcp",
        authorization_type=AuthorizationType.BEARER_TOKEN,
        bearer_token_value="test-token-123",
    )

    # AND a stubbed tool listing containing a single tool
    raw_tools = [
        {
            "name": "resolve-library-id",
            "description": "Resolves library names to IDs",
            "inputSchema": {
                "type": "object",
                "properties": {"libraryName": {"type": "string"}},
                "required": ["libraryName"],
            },
        }
    ]
    # Built separately so the assertion does not compare an object with itself.
    expected_parameters = dict(
        type="object",
        properties={"libraryName": {"type": "string"}},
        required=["libraryName"],
    )

    with mock.patch("vellum.workflows.integrations.mcp_service.asyncio.run", return_value=raw_tools):
        # WHEN the tool definitions are hydrated
        hydrated = MCPService().hydrate_tool_definitions(server)

        # THEN exactly one fully populated MCPToolDefinition comes back
        assert len(hydrated) == 1
        only_tool = hydrated[0]
        assert isinstance(only_tool, MCPToolDefinition)
        assert only_tool.name == "resolve-library-id"
        assert only_tool.description == "Resolves library names to IDs"
        assert only_tool.server == server
        assert only_tool.parameters == expected_parameters
|
209
|
+
|
210
|
+
|
211
|
+
def test_mcp_service_list_tools_handles_errors():
    """list_tools returns an empty list instead of propagating an exception raised while fetching tools."""
    # GIVEN a minimal MCP server configuration
    server = MCPServer(name="test-server", url="https://test.mcp.server.com/mcp")

    # AND a tool fetch (run through asyncio.run) that blows up
    failure = Exception("SSE parsing failed")
    with mock.patch("vellum.workflows.integrations.mcp_service.asyncio.run", side_effect=failure):
        # WHEN tools are listed
        listed = MCPService().list_tools(server)

        # THEN the failure is absorbed and nothing is returned
        assert listed == []
|
@@ -11,6 +11,7 @@ from vellum.workflows.nodes.displayable import (
|
|
11
11
|
PromptDeploymentNode,
|
12
12
|
SearchNode,
|
13
13
|
SubworkflowDeploymentNode,
|
14
|
+
WebSearchNode,
|
14
15
|
)
|
15
16
|
from vellum.workflows.nodes.displayable.bases import (
|
16
17
|
BaseInlinePromptNode as BaseInlinePromptNode,
|
@@ -43,4 +44,5 @@ __all__ = [
|
|
43
44
|
"PromptDeploymentNode",
|
44
45
|
"SearchNode",
|
45
46
|
"SubworkflowDeploymentNode",
|
47
|
+
"WebSearchNode",
|
46
48
|
]
|
@@ -14,6 +14,7 @@ from .prompt_deployment_node import PromptDeploymentNode
|
|
14
14
|
from .search_node import SearchNode
|
15
15
|
from .subworkflow_deployment_node import SubworkflowDeploymentNode
|
16
16
|
from .tool_calling_node import ToolCallingNode
|
17
|
+
from .web_search_node import WebSearchNode
|
17
18
|
|
18
19
|
__all__ = [
|
19
20
|
"APINode",
|
@@ -31,5 +32,6 @@ __all__ = [
|
|
31
32
|
"SearchNode",
|
32
33
|
"TemplatingNode",
|
33
34
|
"ToolCallingNode",
|
35
|
+
"WebSearchNode",
|
34
36
|
"FinalOutputNode",
|
35
37
|
]
|
@@ -0,0 +1,133 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Any, ClassVar, Dict, List, Optional
|
3
|
+
|
4
|
+
from requests import Request, RequestException, Session
|
5
|
+
from requests.exceptions import JSONDecodeError
|
6
|
+
|
7
|
+
from vellum.workflows.errors.types import WorkflowErrorCode
|
8
|
+
from vellum.workflows.exceptions import NodeException
|
9
|
+
from vellum.workflows.nodes.bases import BaseNode
|
10
|
+
from vellum.workflows.outputs import BaseOutputs
|
11
|
+
from vellum.workflows.types.generics import StateType
|
12
|
+
|
13
|
+
logger = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
|
16
|
+
class WebSearchNode(BaseNode[StateType]):
    """
    Used to perform web search using SerpAPI.

    query: str - The search query to execute
    api_key: str - SerpAPI authentication key
    num_results: int - Number of search results to return (default: 10)
    location: Optional[str] - Geographic location filter for search
    """

    query: ClassVar[str] = ""
    api_key: ClassVar[Optional[str]] = None
    num_results: ClassVar[int] = 10
    location: ClassVar[Optional[str]] = None

    class Outputs(BaseOutputs):
        """
        The outputs of the WebSearchNode.

        text: str - Concatenated search result snippets with titles
        urls: List[str] - List of URLs from search results
        results: List[Dict[str, Any]] - Raw search results from SerpAPI
        """

        text: str
        urls: List[str]
        results: List[Dict[str, Any]]

    def _validate(self) -> None:
        """Validate node inputs.

        Raises:
            NodeException: with ``INVALID_INPUTS`` if the query is missing/blank, the API key is
                missing/empty, or ``num_results`` is not a positive integer.
        """
        if not self.query or not isinstance(self.query, str) or not self.query.strip():
            raise NodeException(
                "Query is required and must be a non-empty string", code=WorkflowErrorCode.INVALID_INPUTS
            )

        # An empty key can never authenticate; fail fast instead of spending a request on a certain 401.
        if not self.api_key:
            raise NodeException("API key is required", code=WorkflowErrorCode.INVALID_INPUTS)

        # bool is a subclass of int, so it must be rejected explicitly (num_results=True is a mistake).
        if isinstance(self.num_results, bool) or not isinstance(self.num_results, int) or self.num_results <= 0:
            raise NodeException("num_results must be a positive integer", code=WorkflowErrorCode.INVALID_INPUTS)

    def _build_headers(self) -> Dict[str, str]:
        """Build request headers, forwarding the Vellum client's User-Agent only when it has one."""
        headers: Dict[str, str] = {}
        client_headers = self._context.vellum_client._client_wrapper.get_headers()
        user_agent = client_headers.get("User-Agent")
        # requests rejects header values that are not str/bytes while preparing the request, so a
        # missing User-Agent must be omitted rather than sent as None.
        if user_agent:
            headers["User-Agent"] = user_agent
        return headers

    def _fetch(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Send the search request to SerpAPI and return the decoded JSON object.

        Raises:
            NodeException: ``INVALID_INPUTS`` on HTTP 401; ``PROVIDER_ERROR`` on request preparation or
                transport failures, other HTTP errors, undecodable or non-object JSON, or an ``error``
                field in the payload.
        """
        headers = self._build_headers()

        try:
            prepped = Request(method="GET", url="https://serpapi.com/search", params=params, headers=headers).prepare()
        except Exception as e:
            logger.exception("Failed to prepare SerpAPI request")
            raise NodeException(f"Failed to prepare HTTP request: {e}", code=WorkflowErrorCode.PROVIDER_ERROR) from e

        try:
            with Session() as session:
                response = session.send(prepped, timeout=30)
        except RequestException as e:
            logger.exception("SerpAPI request failed")
            raise NodeException(f"HTTP request failed: {e}", code=WorkflowErrorCode.PROVIDER_ERROR) from e

        if response.status_code == 401:
            logger.error("SerpAPI authentication failed")
            raise NodeException("Invalid API key", code=WorkflowErrorCode.INVALID_INPUTS)
        elif response.status_code == 429:
            logger.warning("SerpAPI rate limit exceeded")
            raise NodeException("Rate limit exceeded", code=WorkflowErrorCode.PROVIDER_ERROR)
        elif response.status_code >= 400:
            logger.error("SerpAPI returned error status: %s", response.status_code)
            raise NodeException(f"SerpAPI error: HTTP {response.status_code}", code=WorkflowErrorCode.PROVIDER_ERROR)

        try:
            json_response = response.json()
        except JSONDecodeError as e:
            logger.exception("Failed to parse SerpAPI response as JSON")
            raise NodeException(
                f"Invalid JSON response from SerpAPI: {e}", code=WorkflowErrorCode.PROVIDER_ERROR
            ) from e

        # Anything other than a JSON object (e.g. a bare list) would otherwise crash below on `.get`.
        if not isinstance(json_response, dict):
            logger.error("SerpAPI returned a non-object JSON payload: %s", type(json_response).__name__)
            raise NodeException(
                "Invalid JSON response from SerpAPI: expected a JSON object", code=WorkflowErrorCode.PROVIDER_ERROR
            )

        if "error" in json_response:
            error_msg = json_response["error"]
            logger.error("SerpAPI returned error: %s", error_msg)
            raise NodeException(f"SerpAPI error: {error_msg}", code=WorkflowErrorCode.PROVIDER_ERROR)

        return json_response

    def run(self) -> Outputs:
        """Run the WebSearchNode to perform web search via SerpAPI.

        Returns:
            Outputs whose ``text`` joins each organic result as "title: snippet" (or whichever of the two
            is present) separated by blank lines, whose ``urls`` lists the non-empty result links, and
            whose ``results`` holds the organic result objects as returned.
        """
        self._validate()

        params: Dict[str, Any] = {
            "q": self.query,
            "api_key": self.api_key,
            "num": self.num_results,
            "engine": "google",
        }
        if self.location:
            params["location"] = self.location

        json_response = self._fetch(params)

        # `or []` also covers an explicit null; non-object entries cannot be read with `.get` and are dropped.
        organic_results = [
            result for result in (json_response.get("organic_results") or []) if isinstance(result, dict)
        ]

        text_results: List[str] = []
        urls: List[str] = []

        for result in organic_results:
            title = result.get("title", "")
            snippet = result.get("snippet", "")
            link = result.get("link", "")

            if title and snippet:
                text_results.append(f"{title}: {snippet}")
            elif title:
                text_results.append(title)
            elif snippet:
                text_results.append(snippet)

            if link:
                urls.append(link)

        return self.Outputs(text="\n\n".join(text_results), urls=urls, results=organic_results)
|
@@ -1,11 +1,25 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
|
-
from
|
2
|
+
from uuid import UUID
|
3
|
+
from typing import TYPE_CHECKING, Iterator, Optional, Type, Union
|
3
4
|
|
4
5
|
from vellum.workflows.events.workflow import WorkflowEvent
|
6
|
+
from vellum.workflows.resolvers.types import LoadStateResult
|
5
7
|
from vellum.workflows.state.base import BaseState
|
6
8
|
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from vellum.workflows.state.context import WorkflowContext
|
11
|
+
from vellum.workflows.workflows.base import BaseWorkflow
|
12
|
+
|
7
13
|
|
8
14
|
class BaseWorkflowResolver(ABC):
    """Abstract base for workflow resolvers.

    A resolver starts unbound: ``_context`` and ``_workflow_class`` stay ``None`` until
    ``register_workflow_instance`` is called with a workflow. Subclasses implement the three
    abstract accessors below.
    """

    def __init__(self):
        # Filled in by register_workflow_instance(); None means "not yet bound to a workflow".
        self._context: Optional["WorkflowContext"] = None
        self._workflow_class: Optional[Type["BaseWorkflow"]] = None

    def register_workflow_instance(self, workflow_instance: "BaseWorkflow") -> None:
        """Bind this resolver to ``workflow_instance`` by remembering its concrete class and its context."""
        self._workflow_class = type(workflow_instance)
        self._context = workflow_instance.context

    @abstractmethod
    def get_latest_execution_events(self) -> Iterator[WorkflowEvent]:
        """Return an iterator of ``WorkflowEvent`` objects (presumably from the most recent execution)."""

    @abstractmethod
    def get_state_snapshot_history(self) -> Iterator[BaseState]:
        """Return an iterator of ``BaseState`` snapshots; ordering is left to the subclass."""

    @abstractmethod
    def load_state(self, previous_execution_id: Optional[Union[UUID, str]] = None) -> Optional[LoadStateResult]:
        """Load state for ``previous_execution_id`` (UUID or string), or return ``None`` when nothing is loaded."""
|
@@ -0,0 +1,97 @@
|
|
1
|
+
import logging
|
2
|
+
from uuid import UUID
|
3
|
+
from typing import Iterator, List, Optional, Tuple, Union
|
4
|
+
|
5
|
+
from vellum.client.types.vellum_span import VellumSpan
|
6
|
+
from vellum.client.types.workflow_execution_initiated_event import WorkflowExecutionInitiatedEvent
|
7
|
+
from vellum.workflows.events.workflow import WorkflowEvent
|
8
|
+
from vellum.workflows.resolvers.base import BaseWorkflowResolver
|
9
|
+
from vellum.workflows.resolvers.types import LoadStateResult
|
10
|
+
from vellum.workflows.state.base import BaseState
|
11
|
+
|
12
|
+
logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
class VellumResolver(BaseWorkflowResolver):
    """Resolver that loads a previous execution's state through the Vellum API.

    Only ``load_state`` talks to the API; the event and snapshot-history accessors return empty iterators.
    """

    def get_latest_execution_events(self) -> Iterator[WorkflowEvent]:
        return iter([])

    def get_state_snapshot_history(self) -> Iterator[BaseState]:
        return iter([])

    def _find_previous_and_root_span(
        self, execution_id: str, spans: List[VellumSpan]
    ) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
        """Locate the trace/span ids of the previous execution and of the root of its chain.

        Scans ``spans`` for the ``workflow.execution`` span whose ``span_id`` equals ``execution_id`` and
        that carries a ``WorkflowExecutionInitiatedEvent``.

        Returns:
            ``(previous_trace_id, root_trace_id, previous_span_id, root_span_id)``; every element is
            ``None`` when no such span/event is found.
        """
        for span in spans:
            if span.name != "workflow.execution" or span.span_id != execution_id:
                continue

            initiated_event = next(
                (event for event in span.events if isinstance(event, WorkflowExecutionInitiatedEvent)), None
            )
            if initiated_event is None:
                continue

            previous_trace_id = initiated_event.trace_id
            previous_span_id = initiated_event.span_id

            root_link = next((link for link in (initiated_event.links or []) if link.type == "ROOT_SPAN"), None)
            if root_link is not None:
                return previous_trace_id, root_link.trace_id, previous_span_id, root_link.span_context.span_id

            # No ROOT_SPAN link (no links at all, or only e.g. TRIGGERED_BY): this execution started its
            # own chain, so it is its own root. Previously this case yielded None and blocked state loading.
            return previous_trace_id, previous_trace_id, previous_span_id, previous_span_id

        return None, None, None, None

    def load_state(self, previous_execution_id: Optional[Union[UUID, str]] = None) -> Optional[LoadStateResult]:
        """Fetch a previous execution's state from Vellum and rebuild it as a state object.

        Args:
            previous_execution_id: Id of the execution to load, as a UUID or string. ``None`` loads nothing.

        Returns:
            A ``LoadStateResult`` with the rebuilt state and the previous/root trace and span ids, or ``None``
            when no id is given, no workflow context is registered, the execution has no state, or its
            initiated event cannot be found.
        """
        if previous_execution_id is None:
            return None
        execution_id = str(previous_execution_id) if isinstance(previous_execution_id, UUID) else previous_execution_id

        if not self._context:
            logger.warning("Cannot load state: No workflow context registered")
            return None

        response = self._context.vellum_client.workflow_executions.retrieve_workflow_execution_detail(
            execution_id=execution_id,
        )

        if response.state is None:
            return None

        previous_trace_id, root_trace_id, previous_span_id, root_span_id = self._find_previous_and_root_span(
            execution_id, response.spans
        )

        if previous_trace_id is None or root_trace_id is None or previous_span_id is None or root_span_id is None:
            logger.warning("Could not find required execution events for state loading")
            return None

        # Build a filtered copy instead of popping "meta" off the API response, which mutated the caller-visible
        # response object; "meta" is not passed to the state constructor.
        state_values = {key: value for key, value in response.state.items() if key != "meta"}

        if self._workflow_class:
            state_class = self._workflow_class.get_state_class()
            state = state_class(**state_values)
        else:
            logger.warning("No workflow class registered, falling back to BaseState")
            state = BaseState(**state_values)

        return LoadStateResult(
            state=state,
            previous_trace_id=previous_trace_id,
            previous_span_id=previous_span_id,
            root_trace_id=root_trace_id,
            root_span_id=root_span_id,
        )
|