vellum-ai 1.2.5__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/__init__.py +8 -0
- vellum/client/README.md +1 -1
- vellum/client/core/client_wrapper.py +2 -2
- vellum/client/reference.md +0 -9
- vellum/client/resources/workflow_sandboxes/client.py +0 -12
- vellum/client/resources/workflow_sandboxes/raw_client.py +2 -10
- vellum/client/types/__init__.py +8 -0
- vellum/client/types/deployment_read.py +5 -5
- vellum/client/types/slim_deployment_read.py +5 -5
- vellum/client/types/slim_workflow_deployment.py +5 -5
- vellum/client/types/workflow_deployment_read.py +5 -5
- vellum/client/types/workflow_request_audio_input_request.py +30 -0
- vellum/client/types/workflow_request_document_input_request.py +30 -0
- vellum/client/types/workflow_request_image_input_request.py +30 -0
- vellum/client/types/workflow_request_input_request.py +8 -0
- vellum/client/types/workflow_request_video_input_request.py +30 -0
- vellum/types/workflow_request_audio_input_request.py +3 -0
- vellum/types/workflow_request_document_input_request.py +3 -0
- vellum/types/workflow_request_image_input_request.py +3 -0
- vellum/types/workflow_request_video_input_request.py +3 -0
- vellum/workflows/events/types.py +6 -1
- vellum/workflows/events/workflow.py +9 -2
- vellum/workflows/integrations/tests/test_mcp_service.py +106 -1
- vellum/workflows/nodes/__init__.py +2 -0
- vellum/workflows/nodes/displayable/__init__.py +2 -0
- vellum/workflows/nodes/displayable/tool_calling_node/utils.py +11 -0
- vellum/workflows/nodes/displayable/web_search_node/__init__.py +3 -0
- vellum/workflows/nodes/displayable/web_search_node/node.py +133 -0
- vellum/workflows/nodes/displayable/web_search_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/web_search_node/tests/test_node.py +319 -0
- vellum/workflows/resolvers/base.py +3 -2
- vellum/workflows/resolvers/resolver.py +62 -7
- vellum/workflows/resolvers/tests/test_resolver.py +79 -7
- vellum/workflows/resolvers/types.py +11 -0
- vellum/workflows/runner/runner.py +49 -1
- vellum/workflows/state/context.py +41 -7
- vellum/workflows/utils/zip.py +46 -0
- vellum/workflows/workflows/base.py +10 -0
- {vellum_ai-1.2.5.dist-info → vellum_ai-1.3.1.dist-info}/METADATA +1 -1
- {vellum_ai-1.2.5.dist-info → vellum_ai-1.3.1.dist-info}/RECORD +48 -34
- vellum_cli/tests/test_init.py +7 -24
- vellum_cli/tests/test_pull.py +27 -52
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +7 -33
- vellum_ee/workflows/display/utils/events.py +3 -0
- vellum_ee/workflows/tests/test_server.py +115 -0
- {vellum_ai-1.2.5.dist-info → vellum_ai-1.3.1.dist-info}/LICENSE +0 -0
- {vellum_ai-1.2.5.dist-info → vellum_ai-1.3.1.dist-info}/WHEEL +0 -0
- {vellum_ai-1.2.5.dist-info → vellum_ai-1.3.1.dist-info}/entry_points.txt +0 -0
@@ -11,6 +11,7 @@ from vellum.workflows.nodes.displayable import (
|
|
11
11
|
PromptDeploymentNode,
|
12
12
|
SearchNode,
|
13
13
|
SubworkflowDeploymentNode,
|
14
|
+
WebSearchNode,
|
14
15
|
)
|
15
16
|
from vellum.workflows.nodes.displayable.bases import (
|
16
17
|
BaseInlinePromptNode as BaseInlinePromptNode,
|
@@ -43,4 +44,5 @@ __all__ = [
|
|
43
44
|
"PromptDeploymentNode",
|
44
45
|
"SearchNode",
|
45
46
|
"SubworkflowDeploymentNode",
|
47
|
+
"WebSearchNode",
|
46
48
|
]
|
@@ -14,6 +14,7 @@ from .prompt_deployment_node import PromptDeploymentNode
|
|
14
14
|
from .search_node import SearchNode
|
15
15
|
from .subworkflow_deployment_node import SubworkflowDeploymentNode
|
16
16
|
from .tool_calling_node import ToolCallingNode
|
17
|
+
from .web_search_node import WebSearchNode
|
17
18
|
|
18
19
|
__all__ = [
|
19
20
|
"APINode",
|
@@ -31,5 +32,6 @@ __all__ = [
|
|
31
32
|
"SearchNode",
|
32
33
|
"TemplatingNode",
|
33
34
|
"ToolCallingNode",
|
35
|
+
"WebSearchNode",
|
34
36
|
"FinalOutputNode",
|
35
37
|
]
|
@@ -54,12 +54,23 @@ class FunctionCallNodeMixin:
|
|
54
54
|
return function_call.value.arguments or {}
|
55
55
|
return {}
|
56
56
|
|
57
|
+
def _extract_function_call_id(self) -> Optional[str]:
|
58
|
+
"""Extract function call ID from function call output."""
|
59
|
+
current_index = getattr(self, "state").current_prompt_output_index
|
60
|
+
if self.function_call_output and len(self.function_call_output) > current_index:
|
61
|
+
function_call = self.function_call_output[current_index]
|
62
|
+
if function_call.type == "FUNCTION_CALL" and function_call.value is not None:
|
63
|
+
return function_call.value.id
|
64
|
+
return None
|
65
|
+
|
57
66
|
def _add_function_result_to_chat_history(self, result: Any, state: ToolCallingState) -> None:
|
58
67
|
"""Add function execution result to chat history."""
|
68
|
+
function_call_id = self._extract_function_call_id()
|
59
69
|
state.chat_history.append(
|
60
70
|
ChatMessage(
|
61
71
|
role="FUNCTION",
|
62
72
|
content=StringChatMessageContent(value=json.dumps(result, cls=DefaultStateEncoder)),
|
73
|
+
source=function_call_id,
|
63
74
|
)
|
64
75
|
)
|
65
76
|
with state.__quiet__():
|
@@ -0,0 +1,133 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Any, ClassVar, Dict, List, Optional
|
3
|
+
|
4
|
+
from requests import Request, RequestException, Session
|
5
|
+
from requests.exceptions import JSONDecodeError
|
6
|
+
|
7
|
+
from vellum.workflows.errors.types import WorkflowErrorCode
|
8
|
+
from vellum.workflows.exceptions import NodeException
|
9
|
+
from vellum.workflows.nodes.bases import BaseNode
|
10
|
+
from vellum.workflows.outputs import BaseOutputs
|
11
|
+
from vellum.workflows.types.generics import StateType
|
12
|
+
|
13
|
+
logger = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
|
16
|
+
class WebSearchNode(BaseNode[StateType]):
    """
    Used to perform web search using SerpAPI.

    query: str - The search query to execute
    api_key: Optional[str] - SerpAPI authentication key (must be non-empty at run time)
    num_results: int - Number of search results to return (default: 10)
    location: Optional[str] - Geographic location filter for search
    """

    query: ClassVar[str] = ""
    api_key: ClassVar[Optional[str]] = None
    num_results: ClassVar[int] = 10
    location: ClassVar[Optional[str]] = None

    class Outputs(BaseOutputs):
        """
        The outputs of the WebSearchNode.

        text: str - Concatenated search result snippets with titles
        urls: List[str] - List of URLs from search results
        results: List[Dict[str, Any]] - Raw search results from SerpAPI
        """

        text: str
        urls: List[str]
        results: List[Dict[str, Any]]

    def _validate(self) -> None:
        """Validate node inputs.

        Raises:
            NodeException: with ``INVALID_INPUTS`` when the query is empty/blank or not a
                string, the API key is missing or empty, or ``num_results`` is not a
                positive integer.
        """
        if not self.query or not isinstance(self.query, str) or not self.query.strip():
            raise NodeException(
                "Query is required and must be a non-empty string", code=WorkflowErrorCode.INVALID_INPUTS
            )

        # An empty string is treated like a missing key; sending it would only earn a 401.
        if not self.api_key:
            raise NodeException("API key is required", code=WorkflowErrorCode.INVALID_INPUTS)

        # bool is a subclass of int, so reject it explicitly (True would otherwise become num=1).
        if isinstance(self.num_results, bool) or not isinstance(self.num_results, int) or self.num_results <= 0:
            raise NodeException("num_results must be a positive integer", code=WorkflowErrorCode.INVALID_INPUTS)

    def run(self) -> Outputs:
        """Run the WebSearchNode to perform web search via SerpAPI.

        Returns:
            Outputs: ``text`` joins one "title: snippet" (or whichever of the two exists)
            entry per result with blank lines; ``urls`` lists the non-empty ``link``
            values; ``results`` is the raw ``organic_results`` list.

        Raises:
            NodeException: ``INVALID_INPUTS`` for invalid inputs or an HTTP 401.
            NodeException: ``PROVIDER_ERROR`` for any of:
                - request preparation or transport failures;
                - HTTP 429 and other 4xx/5xx statuses;
                - a non-JSON or non-object body;
                - an ``error`` field in the response.
        """
        self._validate()

        params: Dict[str, Any] = {
            "q": self.query,
            "api_key": self.api_key,
            "num": self.num_results,
            "engine": "google",
        }
        if self.location:
            params["location"] = self.location

        headers: Dict[str, str] = {}
        client_headers = self._context.vellum_client._client_wrapper.get_headers()
        user_agent = client_headers.get("User-Agent")
        # requests rejects header values that are not str/bytes when preparing the request,
        # so only forward the User-Agent when the Vellum client actually supplies one.
        if user_agent:
            headers["User-Agent"] = user_agent

        try:
            prepped = Request(method="GET", url="https://serpapi.com/search", params=params, headers=headers).prepare()
        except Exception as e:
            logger.exception("Failed to prepare SerpAPI request")
            raise NodeException(f"Failed to prepare HTTP request: {e}", code=WorkflowErrorCode.PROVIDER_ERROR) from e

        try:
            with Session() as session:
                response = session.send(prepped, timeout=30)
        except RequestException as e:
            logger.exception("SerpAPI request failed")
            raise NodeException(f"HTTP request failed: {e}", code=WorkflowErrorCode.PROVIDER_ERROR) from e

        if response.status_code == 401:
            logger.error("SerpAPI authentication failed")
            raise NodeException("Invalid API key", code=WorkflowErrorCode.INVALID_INPUTS)
        elif response.status_code == 429:
            logger.warning("SerpAPI rate limit exceeded")
            raise NodeException("Rate limit exceeded", code=WorkflowErrorCode.PROVIDER_ERROR)
        elif response.status_code >= 400:
            logger.error("SerpAPI returned error status: %s", response.status_code)
            raise NodeException(f"SerpAPI error: HTTP {response.status_code}", code=WorkflowErrorCode.PROVIDER_ERROR)

        try:
            json_response = response.json()
        except JSONDecodeError as e:
            logger.exception("Failed to parse SerpAPI response as JSON")
            raise NodeException(
                f"Invalid JSON response from SerpAPI: {e}", code=WorkflowErrorCode.PROVIDER_ERROR
            ) from e

        # Everything below relies on dict access; a list/scalar payload would otherwise
        # crash with an AttributeError instead of a NodeException.
        if not isinstance(json_response, dict):
            logger.error("SerpAPI returned an unexpected JSON payload of type %s", type(json_response).__name__)
            raise NodeException(
                "Invalid JSON response from SerpAPI: expected a JSON object",
                code=WorkflowErrorCode.PROVIDER_ERROR,
            )

        if "error" in json_response:
            error_msg = json_response["error"]
            logger.error("SerpAPI returned error: %s", error_msg)
            raise NodeException(f"SerpAPI error: {error_msg}", code=WorkflowErrorCode.PROVIDER_ERROR)

        # ``or []`` also covers an explicit ``"organic_results": null``.
        organic_results = json_response.get("organic_results") or []

        text_results: List[str] = []
        urls: List[str] = []

        for result in organic_results:
            if not isinstance(result, dict):
                # Skip malformed entries instead of failing the whole search.
                continue

            title = result.get("title", "")
            snippet = result.get("snippet", "")
            link = result.get("link", "")

            if title and snippet:
                text_results.append(f"{title}: {snippet}")
            elif title:
                text_results.append(title)
            elif snippet:
                text_results.append(snippet)

            if link:
                urls.append(link)

        return self.Outputs(text="\n\n".join(text_results), urls=urls, results=organic_results)
|
File without changes
|
@@ -0,0 +1,319 @@
|
|
1
|
+
import pytest
|
2
|
+
from unittest.mock import MagicMock
|
3
|
+
|
4
|
+
import requests
|
5
|
+
|
6
|
+
from vellum.workflows.errors.types import WorkflowErrorCode
|
7
|
+
from vellum.workflows.exceptions import NodeException
|
8
|
+
from vellum.workflows.inputs import BaseInputs
|
9
|
+
from vellum.workflows.state import BaseState
|
10
|
+
from vellum.workflows.state.base import StateMeta
|
11
|
+
|
12
|
+
from ..node import WebSearchNode
|
13
|
+
|
14
|
+
|
15
|
+
@pytest.fixture
def base_node_setup(vellum_client):
    """Provide a WebSearchNode whose inputs come from workflow inputs.

    The workflow inputs are query="test query", api_key="test_api_key" and
    num_results=3. The node's context is a MagicMock carrying the ``vellum_client``
    fixture.
    """

    class Inputs(BaseInputs):
        query: str
        api_key: str
        num_results: int

    class State(BaseState):
        pass

    class TestableWebSearchNode(WebSearchNode):
        query = Inputs.query
        api_key = Inputs.api_key
        num_results = Inputs.num_results

    workflow_inputs = Inputs(query="test query", api_key="test_api_key", num_results=3)
    node_state = State(meta=StateMeta(workflow_inputs=workflow_inputs))

    mocked_context = MagicMock()
    mocked_context.vellum_client = vellum_client

    return TestableWebSearchNode(state=node_state, context=mocked_context)
|
37
|
+
|
38
|
+
|
39
|
+
def test_successful_search_with_results(base_node_setup, requests_mock):
    """Three complete organic results map onto the text, urls and results outputs."""
    # GIVEN SerpAPI answers with three complete organic results
    organic_results = [
        {
            "title": "First Result",
            "snippet": "This is the first search result snippet",
            "link": "https://example1.com",
            "position": 1,
        },
        {
            "title": "Second Result",
            "snippet": "This is the second search result snippet",
            "link": "https://example2.com",
            "position": 2,
        },
        {
            "title": "Third Result",
            "snippet": "This is the third search result snippet",
            "link": "https://example3.com",
            "position": 3,
        },
    ]
    requests_mock.get("https://serpapi.com/search", json={"organic_results": organic_results})

    # WHEN the node runs
    outputs = base_node_setup.run()

    # THEN every result becomes a "title: snippet" paragraph
    assert outputs.text == "\n\n".join(
        [
            "First Result: This is the first search result snippet",
            "Second Result: This is the second search result snippet",
            "Third Result: This is the third search result snippet",
        ]
    )

    # AND the links are collected in order
    assert outputs.urls == ["https://example1.com", "https://example2.com", "https://example3.com"]

    # AND the raw results are passed through untouched
    assert outputs.results == organic_results

    # AND the outgoing query string carries the node's inputs
    assert requests_mock.last_request.qs == {
        "q": ["test query"],
        "api_key": ["test_api_key"],
        "num": ["3"],
        "engine": ["google"],
    }
|
91
|
+
|
92
|
+
|
93
|
+
def test_search_with_location_parameter(base_node_setup, requests_mock):
    """A configured location is forwarded to SerpAPI as the ``location`` query parameter."""
    # GIVEN the node has a location and SerpAPI returns no results
    node = base_node_setup
    node.location = "New York, NY"
    requests_mock.get("https://serpapi.com/search", json={"organic_results": []})

    # WHEN the node runs
    node.run()

    # THEN the location appears in the query string (compared case-insensitively,
    # since the mock's parsed query values may come back lowercased)
    sent_query = requests_mock.last_request.qs
    assert "location" in sent_query
    assert sent_query["location"][0].lower() == "new york, ny"
|
106
|
+
|
107
|
+
|
108
|
+
def test_authentication_error_401(base_node_setup, requests_mock):
    """An HTTP 401 from SerpAPI surfaces as an INVALID_INPUTS NodeException."""
    # GIVEN SerpAPI rejects the credentials
    requests_mock.get("https://serpapi.com/search", status_code=401)

    # WHEN the node runs
    with pytest.raises(NodeException) as raised:
        base_node_setup.run()

    # THEN the error blames the API key
    error = raised.value
    assert error.code == WorkflowErrorCode.INVALID_INPUTS
    assert "Invalid API key" in str(error)
|
120
|
+
|
121
|
+
|
122
|
+
def test_rate_limit_error_429(base_node_setup, requests_mock):
    """An HTTP 429 from SerpAPI surfaces as a PROVIDER_ERROR NodeException."""
    # GIVEN SerpAPI is throttling requests
    requests_mock.get("https://serpapi.com/search", status_code=429)

    # WHEN the node runs
    with pytest.raises(NodeException) as raised:
        base_node_setup.run()

    # THEN a rate-limit provider error is reported
    error = raised.value
    assert error.code == WorkflowErrorCode.PROVIDER_ERROR
    assert "Rate limit exceeded" in str(error)
|
134
|
+
|
135
|
+
|
136
|
+
def test_server_error_500(base_node_setup, requests_mock):
    """An HTTP 500 from SerpAPI surfaces as a PROVIDER_ERROR NodeException naming the status."""
    # GIVEN SerpAPI fails internally
    requests_mock.get("https://serpapi.com/search", status_code=500)

    # WHEN the node runs
    with pytest.raises(NodeException) as raised:
        base_node_setup.run()

    # THEN the status code is included in the provider error
    error = raised.value
    assert error.code == WorkflowErrorCode.PROVIDER_ERROR
    assert "SerpAPI error: HTTP 500" in str(error)
|
148
|
+
|
149
|
+
|
150
|
+
def test_invalid_json_response(base_node_setup, requests_mock):
    """A body that is not JSON surfaces as a PROVIDER_ERROR NodeException."""
    # GIVEN SerpAPI replies with plain text
    requests_mock.get("https://serpapi.com/search", text="Not JSON")

    # WHEN the node runs
    with pytest.raises(NodeException) as raised:
        base_node_setup.run()

    # THEN the parse failure is reported as a provider error
    error = raised.value
    assert error.code == WorkflowErrorCode.PROVIDER_ERROR
    assert "Invalid JSON response" in str(error)
|
162
|
+
|
163
|
+
|
164
|
+
def test_serpapi_error_in_response(base_node_setup, requests_mock):
    """An ``error`` field in a successful response body is raised as a PROVIDER_ERROR."""
    # GIVEN SerpAPI embeds an error message in its JSON body
    requests_mock.get("https://serpapi.com/search", json={"error": "Invalid search parameters"})

    # WHEN the node runs
    with pytest.raises(NodeException) as raised:
        base_node_setup.run()

    # THEN SerpAPI's message is propagated
    error = raised.value
    assert error.code == WorkflowErrorCode.PROVIDER_ERROR
    assert "Invalid search parameters" in str(error)
|
176
|
+
|
177
|
+
|
178
|
+
def test_empty_query_validation(vellum_client):
    """An empty query fails input validation with INVALID_INPUTS."""

    # GIVEN a node whose query is the empty string
    class TestNode(WebSearchNode):
        query = ""
        api_key = "test_key"
        num_results = 10

    mocked_context = MagicMock()
    mocked_context.vellum_client = vellum_client
    empty_state = BaseState(meta=StateMeta(workflow_inputs=BaseInputs()))
    node = TestNode(state=empty_state, context=mocked_context)

    # WHEN the node runs
    with pytest.raises(NodeException) as raised:
        node.run()

    # THEN validation rejects the query
    error = raised.value
    assert error.code == WorkflowErrorCode.INVALID_INPUTS
    assert "Query is required" in str(error)
|
198
|
+
|
199
|
+
|
200
|
+
def test_missing_api_key_validation(vellum_client):
    """A missing API key fails input validation with INVALID_INPUTS."""

    # GIVEN a node with no API key configured
    class TestNode(WebSearchNode):
        query = "test query"
        api_key = None
        num_results = 10

    mocked_context = MagicMock()
    mocked_context.vellum_client = vellum_client
    empty_state = BaseState(meta=StateMeta(workflow_inputs=BaseInputs()))
    node = TestNode(state=empty_state, context=mocked_context)

    # WHEN the node runs
    with pytest.raises(NodeException) as raised:
        node.run()

    # THEN validation demands an API key
    error = raised.value
    assert error.code == WorkflowErrorCode.INVALID_INPUTS
    assert "API key is required" in str(error)
|
220
|
+
|
221
|
+
|
222
|
+
def test_invalid_num_results_validation(vellum_client):
    """A non-positive ``num_results`` fails input validation with INVALID_INPUTS."""

    # GIVEN a node asking for a negative number of results
    class TestNode(WebSearchNode):
        query = "test query"
        api_key = "test_key"
        num_results = -1

    mocked_context = MagicMock()
    mocked_context.vellum_client = vellum_client
    empty_state = BaseState(meta=StateMeta(workflow_inputs=BaseInputs()))
    node = TestNode(state=empty_state, context=mocked_context)

    # WHEN the node runs
    with pytest.raises(NodeException) as raised:
        node.run()

    # THEN validation rejects num_results
    error = raised.value
    assert error.code == WorkflowErrorCode.INVALID_INPUTS
    assert "num_results must be a positive integer" in str(error)
|
242
|
+
|
243
|
+
|
244
|
+
def test_empty_organic_results(base_node_setup, requests_mock):
    """No organic results produce empty text, urls and results outputs."""
    # GIVEN SerpAPI finds nothing
    requests_mock.get("https://serpapi.com/search", json={"organic_results": []})

    # WHEN the node runs
    outputs = base_node_setup.run()

    # THEN every output is empty
    assert (outputs.text, outputs.urls, outputs.results) == ("", [], [])
|
256
|
+
|
257
|
+
|
258
|
+
def test_missing_fields_in_results(base_node_setup, requests_mock):
    """Results missing a title, snippet or link still contribute what they have."""
    # GIVEN SerpAPI returns partially populated results
    partial_results = [
        {"title": "Only Title"},  # no snippet, no link
        {"snippet": "Only snippet, no title or link"},  # no title, no link
        {"title": "Title with link", "link": "https://example.com"},  # no snippet
        {"position": 4},  # nothing usable; contributes no text or URL
    ]
    requests_mock.get("https://serpapi.com/search", json={"organic_results": partial_results})

    # WHEN the node runs
    outputs = base_node_setup.run()

    # THEN the text falls back to whichever of title/snippet is present
    assert outputs.text == "\n\n".join(["Only Title", "Only snippet, no title or link", "Title with link"])

    # AND only results with a link contribute a URL
    assert outputs.urls == ["https://example.com"]
|
294
|
+
|
295
|
+
|
296
|
+
def test_request_timeout_handling(base_node_setup, requests_mock):
    """A transport-level timeout surfaces as a PROVIDER_ERROR NodeException."""
    # GIVEN the HTTP call times out
    timeout_error = requests.exceptions.Timeout("Connection timed out")
    requests_mock.get("https://serpapi.com/search", exc=timeout_error)

    # WHEN the node runs
    with pytest.raises(NodeException) as raised:
        base_node_setup.run()

    # THEN the failure is reported as a provider error
    error = raised.value
    assert error.code == WorkflowErrorCode.PROVIDER_ERROR
    assert "HTTP request failed" in str(error)
|
308
|
+
|
309
|
+
|
310
|
+
def test_user_agent_header_included(base_node_setup, requests_mock):
    """The Vellum client's User-Agent header is forwarded to SerpAPI."""
    # GIVEN SerpAPI responds successfully
    requests_mock.get("https://serpapi.com/search", json={"organic_results": []})

    # WHEN the node runs
    base_node_setup.run()

    # THEN the outgoing request reuses the Vellum SDK User-Agent
    sent_headers = requests_mock.last_request.headers
    assert sent_headers["User-Agent"] == "vellum-python-sdk/1.0.0"
|
@@ -1,8 +1,9 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
2
|
from uuid import UUID
|
3
|
-
from typing import TYPE_CHECKING, Iterator, Optional, Type
|
3
|
+
from typing import TYPE_CHECKING, Iterator, Optional, Type, Union
|
4
4
|
|
5
5
|
from vellum.workflows.events.workflow import WorkflowEvent
|
6
|
+
from vellum.workflows.resolvers.types import LoadStateResult
|
6
7
|
from vellum.workflows.state.base import BaseState
|
7
8
|
|
8
9
|
if TYPE_CHECKING:
|
@@ -28,5 +29,5 @@ class BaseWorkflowResolver(ABC):
|
|
28
29
|
pass
|
29
30
|
|
30
31
|
@abstractmethod
|
31
|
-
def load_state(self, previous_execution_id: Optional[UUID] = None) -> Optional[
|
32
|
+
def load_state(self, previous_execution_id: Optional[Union[UUID, str]] = None) -> Optional[LoadStateResult]:
|
32
33
|
pass
|
@@ -1,10 +1,13 @@
|
|
1
1
|
import logging
|
2
2
|
from uuid import UUID
|
3
|
-
from typing import Iterator, Optional
|
3
|
+
from typing import Iterator, List, Optional, Tuple, Union
|
4
4
|
|
5
|
+
from vellum.client.types.vellum_span import VellumSpan
|
6
|
+
from vellum.client.types.workflow_execution_initiated_event import WorkflowExecutionInitiatedEvent
|
5
7
|
from vellum.workflows.events.workflow import WorkflowEvent
|
6
8
|
from vellum.workflows.resolvers.base import BaseWorkflowResolver
|
7
|
-
from vellum.workflows.
|
9
|
+
from vellum.workflows.resolvers.types import LoadStateResult
|
10
|
+
from vellum.workflows.state.base import BaseState
|
8
11
|
|
9
12
|
logger = logging.getLogger(__name__)
|
10
13
|
|
@@ -16,7 +19,42 @@ class VellumResolver(BaseWorkflowResolver):
|
|
16
19
|
def get_state_snapshot_history(self) -> Iterator[BaseState]:
    """Return an iterator over prior state snapshots.

    This resolver does not provide snapshot history, so the iterator is always empty.
    """
    return iter(())
|
18
21
|
|
19
|
-
def
|
22
|
+
def _find_previous_and_root_span(
    self, execution_id: str, spans: List[VellumSpan]
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    """Locate trace/span ids for a previous execution and for the root of its chain.

    Scans ``spans`` for the first ``"workflow.execution"`` span whose ``span_id``
    equals ``execution_id`` and that contains a ``WorkflowExecutionInitiatedEvent``.

    Returns:
        ``(previous_trace_id, root_trace_id, previous_span_id, root_span_id)``.
        - The "previous" ids come from the initiated event itself.
        - The "root" ids come from its ``ROOT_SPAN`` link. When the event has no
          links it is treated as its own root.
        - Root ids stay ``None`` when links exist but none is a ``ROOT_SPAN``.
        - All four are ``None`` when no matching span/event is found.
    """
    for span in spans:
        if span.name != "workflow.execution" or span.span_id != execution_id:
            continue

        initiated_event = next(
            (event for event in span.events if isinstance(event, WorkflowExecutionInitiatedEvent)), None
        )
        if not initiated_event:
            # Matching span without an initiated event: keep scanning.
            continue

        previous_trace_id = initiated_event.trace_id
        previous_span_id = initiated_event.span_id
        links = initiated_event.links

        if not links:
            # No links means this was the first execution, so it is also the root.
            return previous_trace_id, initiated_event.trace_id, previous_span_id, initiated_event.span_id

        root_link = next((link for link in links if link.type == "ROOT_SPAN"), None)
        if not root_link:
            return previous_trace_id, None, previous_span_id, None

        return previous_trace_id, root_link.trace_id, previous_span_id, root_link.span_context.span_id

    return None, None, None, None
|
53
|
+
|
54
|
+
def load_state(self, previous_execution_id: Optional[Union[UUID, str]] = None) -> Optional[LoadStateResult]:
|
55
|
+
if isinstance(previous_execution_id, UUID):
|
56
|
+
previous_execution_id = str(previous_execution_id)
|
57
|
+
|
20
58
|
if previous_execution_id is None:
|
21
59
|
return None
|
22
60
|
|
@@ -26,17 +64,34 @@ class VellumResolver(BaseWorkflowResolver):
|
|
26
64
|
|
27
65
|
client = self._context.vellum_client
|
28
66
|
response = client.workflow_executions.retrieve_workflow_execution_detail(
|
29
|
-
execution_id=
|
67
|
+
execution_id=previous_execution_id,
|
30
68
|
)
|
31
69
|
|
32
70
|
if response.state is None:
|
33
71
|
return None
|
34
72
|
|
35
|
-
|
73
|
+
previous_trace_id, root_trace_id, previous_span_id, root_span_id = self._find_previous_and_root_span(
|
74
|
+
previous_execution_id, response.spans
|
75
|
+
)
|
76
|
+
|
77
|
+
if previous_trace_id is None or root_trace_id is None or previous_span_id is None or root_span_id is None:
|
78
|
+
logger.warning("Could not find required execution events for state loading")
|
79
|
+
return None
|
80
|
+
|
81
|
+
if "meta" in response.state:
|
82
|
+
response.state.pop("meta")
|
36
83
|
|
37
84
|
if self._workflow_class:
|
38
85
|
state_class = self._workflow_class.get_state_class()
|
39
|
-
|
86
|
+
state = state_class(**response.state)
|
40
87
|
else:
|
41
88
|
logger.warning("No workflow class registered, falling back to BaseState")
|
42
|
-
|
89
|
+
state = BaseState(**response.state)
|
90
|
+
|
91
|
+
return LoadStateResult(
|
92
|
+
state=state,
|
93
|
+
previous_trace_id=previous_trace_id,
|
94
|
+
previous_span_id=previous_span_id,
|
95
|
+
root_trace_id=root_trace_id,
|
96
|
+
root_span_id=root_span_id,
|
97
|
+
)
|