vellum-ai 1.0.4__py3-none-any.whl → 1.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/__init__.py +0 -6
- vellum/client/core/client_wrapper.py +2 -2
- vellum/client/types/__init__.py +0 -6
- vellum/client/types/organization_read.py +1 -2
- vellum/workflows/events/context.py +111 -0
- vellum/workflows/integrations/__init__.py +0 -0
- vellum/workflows/integrations/composio_service.py +138 -0
- vellum/workflows/nodes/displayable/bases/api_node/node.py +27 -9
- vellum/workflows/nodes/displayable/bases/api_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/bases/api_node/tests/test_node.py +47 -0
- vellum/workflows/nodes/displayable/tool_calling_node/tests/test_composio_service.py +63 -58
- vellum/workflows/nodes/displayable/tool_calling_node/tests/test_utils.py +21 -1
- vellum/workflows/nodes/displayable/tool_calling_node/utils.py +124 -59
- vellum/workflows/types/definition.py +4 -2
- vellum/workflows/utils/functions.py +13 -1
- vellum/workflows/utils/tests/test_functions.py +32 -1
- {vellum_ai-1.0.4.dist-info → vellum_ai-1.0.6.dist-info}/METADATA +1 -3
- {vellum_ai-1.0.4.dist-info → vellum_ai-1.0.6.dist-info}/RECORD +26 -27
- vellum_cli/push.py +11 -2
- vellum_cli/tests/test_push.py +57 -1
- vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +2 -0
- vellum_ee/workflows/display/nodes/vellum/tests/test_code_execution_node.py +16 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_composio_serialization.py +89 -0
- vellum/client/types/organization_limit_config.py +0 -25
- vellum/client/types/quota.py +0 -21
- vellum/client/types/vembda_service_tier_enum.py +0 -5
- vellum/types/organization_limit_config.py +0 -3
- vellum/types/quota.py +0 -3
- vellum/types/vembda_service_tier_enum.py +0 -3
- vellum/workflows/nodes/displayable/tool_calling_node/composio_service.py +0 -83
- {vellum_ai-1.0.4.dist-info → vellum_ai-1.0.6.dist-info}/LICENSE +0 -0
- {vellum_ai-1.0.4.dist-info → vellum_ai-1.0.6.dist-info}/WHEEL +0 -0
- {vellum_ai-1.0.4.dist-info → vellum_ai-1.0.6.dist-info}/entry_points.txt +0 -0
vellum/__init__.py
CHANGED
@@ -288,7 +288,6 @@ from .client.types import (
|
|
288
288
|
OpenAiVectorizerTextEmbedding3SmallRequest,
|
289
289
|
OpenAiVectorizerTextEmbeddingAda002,
|
290
290
|
OpenAiVectorizerTextEmbeddingAda002Request,
|
291
|
-
OrganizationLimitConfig,
|
292
291
|
OrganizationRead,
|
293
292
|
PaginatedContainerImageReadList,
|
294
293
|
PaginatedDeploymentReleaseTagReadList,
|
@@ -329,7 +328,6 @@ from .client.types import (
|
|
329
328
|
PromptRequestStringInput,
|
330
329
|
PromptSettings,
|
331
330
|
PromptVersionBuildConfigSandbox,
|
332
|
-
Quota,
|
333
331
|
RawPromptExecutionOverridesRequest,
|
334
332
|
ReductoChunkerConfig,
|
335
333
|
ReductoChunkerConfigRequest,
|
@@ -527,7 +525,6 @@ from .client.types import (
|
|
527
525
|
VellumVariableExtensions,
|
528
526
|
VellumVariableType,
|
529
527
|
VellumWorkflowExecutionEvent,
|
530
|
-
VembdaServiceTierEnum,
|
531
528
|
WorkflowDeploymentEventExecutionsResponse,
|
532
529
|
WorkflowDeploymentHistoryItem,
|
533
530
|
WorkflowDeploymentParentContext,
|
@@ -939,7 +936,6 @@ __all__ = [
|
|
939
936
|
"OpenAiVectorizerTextEmbedding3SmallRequest",
|
940
937
|
"OpenAiVectorizerTextEmbeddingAda002",
|
941
938
|
"OpenAiVectorizerTextEmbeddingAda002Request",
|
942
|
-
"OrganizationLimitConfig",
|
943
939
|
"OrganizationRead",
|
944
940
|
"PaginatedContainerImageReadList",
|
945
941
|
"PaginatedDeploymentReleaseTagReadList",
|
@@ -980,7 +976,6 @@ __all__ = [
|
|
980
976
|
"PromptRequestStringInput",
|
981
977
|
"PromptSettings",
|
982
978
|
"PromptVersionBuildConfigSandbox",
|
983
|
-
"Quota",
|
984
979
|
"RawPromptExecutionOverridesRequest",
|
985
980
|
"ReductoChunkerConfig",
|
986
981
|
"ReductoChunkerConfigRequest",
|
@@ -1180,7 +1175,6 @@ __all__ = [
|
|
1180
1175
|
"VellumVariableExtensions",
|
1181
1176
|
"VellumVariableType",
|
1182
1177
|
"VellumWorkflowExecutionEvent",
|
1183
|
-
"VembdaServiceTierEnum",
|
1184
1178
|
"WorkflowDeploymentEventExecutionsResponse",
|
1185
1179
|
"WorkflowDeploymentHistoryItem",
|
1186
1180
|
"WorkflowDeploymentParentContext",
|
@@ -25,10 +25,10 @@ class BaseClientWrapper:
|
|
25
25
|
|
26
26
|
def get_headers(self) -> typing.Dict[str, str]:
|
27
27
|
headers: typing.Dict[str, str] = {
|
28
|
-
"User-Agent": "vellum-ai/1.0.
|
28
|
+
"User-Agent": "vellum-ai/1.0.6",
|
29
29
|
"X-Fern-Language": "Python",
|
30
30
|
"X-Fern-SDK-Name": "vellum-ai",
|
31
|
-
"X-Fern-SDK-Version": "1.0.
|
31
|
+
"X-Fern-SDK-Version": "1.0.6",
|
32
32
|
}
|
33
33
|
if self._api_version is not None:
|
34
34
|
headers["X-API-Version"] = self._api_version
|
vellum/client/types/__init__.py
CHANGED
@@ -296,7 +296,6 @@ from .open_ai_vectorizer_text_embedding_3_small import OpenAiVectorizerTextEmbed
|
|
296
296
|
from .open_ai_vectorizer_text_embedding_3_small_request import OpenAiVectorizerTextEmbedding3SmallRequest
|
297
297
|
from .open_ai_vectorizer_text_embedding_ada_002 import OpenAiVectorizerTextEmbeddingAda002
|
298
298
|
from .open_ai_vectorizer_text_embedding_ada_002_request import OpenAiVectorizerTextEmbeddingAda002Request
|
299
|
-
from .organization_limit_config import OrganizationLimitConfig
|
300
299
|
from .organization_read import OrganizationRead
|
301
300
|
from .paginated_container_image_read_list import PaginatedContainerImageReadList
|
302
301
|
from .paginated_deployment_release_tag_read_list import PaginatedDeploymentReleaseTagReadList
|
@@ -337,7 +336,6 @@ from .prompt_request_json_input import PromptRequestJsonInput
|
|
337
336
|
from .prompt_request_string_input import PromptRequestStringInput
|
338
337
|
from .prompt_settings import PromptSettings
|
339
338
|
from .prompt_version_build_config_sandbox import PromptVersionBuildConfigSandbox
|
340
|
-
from .quota import Quota
|
341
339
|
from .raw_prompt_execution_overrides_request import RawPromptExecutionOverridesRequest
|
342
340
|
from .reducto_chunker_config import ReductoChunkerConfig
|
343
341
|
from .reducto_chunker_config_request import ReductoChunkerConfigRequest
|
@@ -551,7 +549,6 @@ from .vellum_variable import VellumVariable
|
|
551
549
|
from .vellum_variable_extensions import VellumVariableExtensions
|
552
550
|
from .vellum_variable_type import VellumVariableType
|
553
551
|
from .vellum_workflow_execution_event import VellumWorkflowExecutionEvent
|
554
|
-
from .vembda_service_tier_enum import VembdaServiceTierEnum
|
555
552
|
from .workflow_deployment_event_executions_response import WorkflowDeploymentEventExecutionsResponse
|
556
553
|
from .workflow_deployment_history_item import WorkflowDeploymentHistoryItem
|
557
554
|
from .workflow_deployment_parent_context import WorkflowDeploymentParentContext
|
@@ -918,7 +915,6 @@ __all__ = [
|
|
918
915
|
"OpenAiVectorizerTextEmbedding3SmallRequest",
|
919
916
|
"OpenAiVectorizerTextEmbeddingAda002",
|
920
917
|
"OpenAiVectorizerTextEmbeddingAda002Request",
|
921
|
-
"OrganizationLimitConfig",
|
922
918
|
"OrganizationRead",
|
923
919
|
"PaginatedContainerImageReadList",
|
924
920
|
"PaginatedDeploymentReleaseTagReadList",
|
@@ -959,7 +955,6 @@ __all__ = [
|
|
959
955
|
"PromptRequestStringInput",
|
960
956
|
"PromptSettings",
|
961
957
|
"PromptVersionBuildConfigSandbox",
|
962
|
-
"Quota",
|
963
958
|
"RawPromptExecutionOverridesRequest",
|
964
959
|
"ReductoChunkerConfig",
|
965
960
|
"ReductoChunkerConfigRequest",
|
@@ -1157,7 +1152,6 @@ __all__ = [
|
|
1157
1152
|
"VellumVariableExtensions",
|
1158
1153
|
"VellumVariableType",
|
1159
1154
|
"VellumWorkflowExecutionEvent",
|
1160
|
-
"VembdaServiceTierEnum",
|
1161
1155
|
"WorkflowDeploymentEventExecutionsResponse",
|
1162
1156
|
"WorkflowDeploymentHistoryItem",
|
1163
1157
|
"WorkflowDeploymentParentContext",
|
@@ -3,7 +3,6 @@
|
|
3
3
|
from ..core.pydantic_utilities import UniversalBaseModel
|
4
4
|
import typing
|
5
5
|
from .new_member_join_behavior_enum import NewMemberJoinBehaviorEnum
|
6
|
-
from .organization_limit_config import OrganizationLimitConfig
|
7
6
|
from ..core.pydantic_utilities import IS_PYDANTIC_V2
|
8
7
|
import pydantic
|
9
8
|
|
@@ -13,7 +12,7 @@ class OrganizationRead(UniversalBaseModel):
|
|
13
12
|
name: str
|
14
13
|
allow_staff_access: typing.Optional[bool] = None
|
15
14
|
new_member_join_behavior: NewMemberJoinBehaviorEnum
|
16
|
-
limit_config:
|
15
|
+
limit_config: typing.Optional[typing.Dict[str, typing.Optional[typing.Any]]] = None
|
17
16
|
|
18
17
|
if IS_PYDANTIC_V2:
|
19
18
|
model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
|
@@ -0,0 +1,111 @@
|
|
1
|
+
"""Monitoring execution context for workflow tracing."""
|
2
|
+
|
3
|
+
import threading
|
4
|
+
from uuid import UUID
|
5
|
+
from typing import Dict, Optional
|
6
|
+
|
7
|
+
from vellum.workflows.context import ExecutionContext
|
8
|
+
|
9
|
+
DEFAULT_TRACE_ID = UUID("00000000-0000-0000-0000-000000000000")

# Thread-local storage for monitoring execution context
_monitoring_execution_context: threading.local = threading.local()
# Thread-local storage for current span_id
_current_span_id: threading.local = threading.local()


class _MonitoringContextStore:
    """
    Thread-safe storage for monitoring contexts.

    Contexts are keyed by (current trace_id, parent span_id, thread id) so that a thread can
    recover its context after the thread-local copy is gone. A single "current" trace_id is
    shared by all threads; DEFAULT_TRACE_ID is never adopted as the current trace_id.
    Retrieval is driven by get_monitoring_execution_context.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._contexts: Dict[str, ExecutionContext] = {}
        # NOTE(review): not read or written anywhere else in this module; kept for compatibility.
        self._thread_contexts: Dict[int, ExecutionContext] = {}
        self._current_trace_id: Optional[UUID] = None

    @staticmethod
    def _context_key(trace_id: Optional[UUID], span_id: Optional[UUID], thread_id: int) -> str:
        """Build the storage key shared by store_context and retrieve_context."""
        return f"trace:{str(trace_id)}:span:{str(span_id)}:thread:{thread_id}"

    def set_current_trace_id(self, trace_id: UUID) -> None:
        """Set the current active trace_id used by all threads; DEFAULT_TRACE_ID is ignored."""
        if trace_id != DEFAULT_TRACE_ID:
            with self._lock:
                self._current_trace_id = trace_id

    def get_current_trace_id(self) -> Optional[UUID]:
        """Get the current active trace_id used by all threads (None if never set)."""
        with self._lock:
            return self._current_trace_id

    def set_current_span_id(self, span_id: UUID) -> None:
        """Set the current active span_id for the calling thread."""
        _current_span_id.span_id = span_id

    def get_current_span_id(self) -> Optional[UUID]:
        """Get the current active span_id for the calling thread (None if unset)."""
        return getattr(_current_span_id, "span_id", None)

    def store_context(self, context: Optional[ExecutionContext]) -> None:
        """
        Persist *context* for later retrieval from the calling thread.

        Contexts that are falsy or have no parent_context are ignored. If no current trace_id
        is set yet and the context carries a non-default trace_id, that trace_id becomes the
        current one. The context is then stored under (current trace_id, parent span_id,
        thread id).
        """
        if not context or context.parent_context is None:
            return

        thread_id = threading.get_ident()
        with self._lock:
            # Adopt the trace_id *before* building the key, and do both under one lock (the lock
            # is not reentrant, so set_current_trace_id cannot be called here). Building the key
            # from a trace_id read before this update would file the first context under
            # "trace:None", where retrieval by the real trace_id can never find it.
            if self._current_trace_id is None and context.trace_id != DEFAULT_TRACE_ID:
                self._current_trace_id = context.trace_id
            key = self._context_key(self._current_trace_id, context.parent_context.span_id, thread_id)
            self._contexts[key] = context

    def retrieve_context(self, trace_id: UUID, span_id: Optional[UUID] = None) -> Optional[ExecutionContext]:
        """
        Look up the context stored by the calling thread for (*trace_id*, *span_id*).

        When *span_id* is not given, the calling thread's current span_id is used; if that is
        also unset, None is returned. Returns None when nothing is stored under the key.
        """
        thread_id = threading.get_ident()
        with self._lock:
            if not span_id:
                span_id = getattr(_current_span_id, "span_id", None)
                if not span_id:
                    return None
            return self._contexts.get(self._context_key(trace_id, span_id, thread_id))


# Global instance for cross-boundary context persistence
_monitoring_context_store = _MonitoringContextStore()
|
85
|
+
|
86
|
+
|
87
|
+
def get_monitoring_execution_context() -> ExecutionContext:
    """
    Return the monitoring execution context for the calling thread.

    Resolution order:
    1. The thread-local context, when it has a non-default trace_id and a parent_context.
    2. The context held by the global store under the current trace_id and this thread's
       span_id. When found, it is also cached back into the thread-local slot.
    3. Otherwise, a new empty ExecutionContext().
    """
    if hasattr(_monitoring_execution_context, "context"):
        cached = _monitoring_execution_context.context
        if cached.trace_id != DEFAULT_TRACE_ID and cached.parent_context:
            return cached

    current_trace_id = _monitoring_context_store.get_current_trace_id()
    thread_span_id = getattr(_current_span_id, "span_id", None)

    # Without a usable (set, non-default) trace_id there is nothing to restore from the store.
    if not current_trace_id or current_trace_id == DEFAULT_TRACE_ID:
        return ExecutionContext()

    restored = _monitoring_context_store.retrieve_context(current_trace_id, thread_span_id)
    if not restored:
        return ExecutionContext()

    _monitoring_execution_context.context = restored
    return restored
|
104
|
+
|
105
|
+
|
106
|
+
def set_monitoring_execution_context(context: ExecutionContext) -> None:
    """
    Install *context* as the calling thread's monitoring execution context.

    When the context has both a truthy trace_id and a parent_context, it is also written to
    the global store so that get_monitoring_execution_context can restore it later.
    """
    _monitoring_execution_context.context = context

    should_persist = context.trace_id and context.parent_context
    if should_persist:
        _monitoring_context_store.store_context(context)
|
File without changes
|
@@ -0,0 +1,138 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
from typing import Any, Dict, List, Optional
|
5
|
+
|
6
|
+
import requests
|
7
|
+
|
8
|
+
from vellum.workflows.exceptions import NodeException
|
9
|
+
|
10
|
+
# Module-level logger; used by ComposioService.get_tool_by_slug in this file.
logger = logging.getLogger(__name__)


@dataclass
class ConnectionInfo:
    """Information about a user's authorized connection.

    Instances are built by ComposioService.get_user_connections from the items of the
    ``/connected_accounts`` response. There, a missing toolkit slug, created_at or
    updated_at defaults to "" and a missing status defaults to "ACTIVE".
    """

    # Connection identifier (the item's "id" field).
    connection_id: str
    # Slug of the toolkit this connection belongs to (item["toolkit"]["slug"]).
    integration_name: str
    created_at: str
    updated_at: str
    # Connection status as reported by Composio; defaults to "ACTIVE".
    status: str = "ACTIVE"
|
22
|
+
|
23
|
+
|
24
|
+
class ComposioService:
    """Composio API client for managing connections and executing tools.

    Failures are reported as ``NodeException``. The underlying exception (for example a
    ``requests.HTTPError``) is chained as ``__cause__``.
    """

    # Environment variables consulted, in order, when no api_key is passed explicitly.
    _API_KEY_ENV_VARS = ("COMPOSIO_API_KEY", "COMPOSIO_KEY")

    def __init__(self, api_key: Optional[str] = None):
        """
        Args:
            api_key: Composio API key. When None, the first non-empty value of
                COMPOSIO_API_KEY / COMPOSIO_KEY in the environment is used.

        Raises:
            NodeException: If no (non-empty) API key is provided or found.
        """
        if api_key is None:
            api_key = self._get_api_key_from_env()

        if not api_key:
            raise NodeException(
                "No Composio API key found. "
                "Please provide an api_key parameter or set one of these environment variables: "
                + ", ".join(self._API_KEY_ENV_VARS)
            )

        self.api_key = api_key
        self.base_url = "https://backend.composio.dev/api/v3"

    @staticmethod
    def _get_api_key_from_env() -> Optional[str]:
        """Return the first non-empty Composio API key found in the environment, else None."""
        for env_var_name in ComposioService._API_KEY_ENV_VARS:
            value = os.environ.get(env_var_name)
            if value:
                return value
        return None

    def _make_request(
        self, endpoint: str, method: str = "GET", params: Optional[dict] = None, json_data: Optional[dict] = None
    ) -> dict:
        """Send a request to the Composio API and return the decoded JSON body.

        Args:
            endpoint: Path appended to ``self.base_url`` (e.g. "/connected_accounts").
            method: "GET" (sends *params* as query string) or "POST" (sends *json_data* as body).
            params: Query parameters for GET requests.
            json_data: JSON body for POST requests.

        Raises:
            NodeException: On an unsupported method, network/HTTP error, or non-JSON body.
                The original exception is chained as ``__cause__``, so callers can inspect it,
                e.g. the HTTP status via ``__cause__.response.status_code``.
        """
        headers = {
            "x-api-key": self.api_key,
            "Content-Type": "application/json",
        }

        url = f"{self.base_url}{endpoint}"

        try:
            if method == "GET":
                response = requests.get(url, headers=headers, params=params or {}, timeout=30)
            elif method == "POST":
                response = requests.post(url, headers=headers, json=json_data or {}, timeout=30)
            else:
                raise ValueError(f"Unsupported HTTP method: {method}")

            response.raise_for_status()
            return response.json()
        except Exception as e:
            raise NodeException(f"Composio API request failed: {e}") from e

    @staticmethod
    def _http_status_code(error: BaseException) -> Optional[int]:
        """Return the HTTP status code of the response behind *error*'s cause, if there is one."""
        response = getattr(error.__cause__, "response", None)
        status_code = getattr(response, "status_code", None)
        return status_code if isinstance(status_code, int) else None

    def get_user_connections(self) -> List[ConnectionInfo]:
        """Get all authorized connections for the user.

        Raises:
            NodeException: If the API request fails.
        """
        response = self._make_request("/connected_accounts")

        return [
            ConnectionInfo(
                connection_id=item.get("id"),
                # "toolkit" may be present but null in the payload; treat that like a missing key.
                integration_name=(item.get("toolkit") or {}).get("slug", ""),
                status=item.get("status", "ACTIVE"),
                created_at=item.get("created_at", ""),
                updated_at=item.get("updated_at", ""),
            )
            for item in response.get("items", [])
        ]

    def get_tool_by_slug(self, tool_slug: str) -> Dict[str, Any]:
        """Get detailed information about a tool using its slug identifier

        Args:
            tool_slug: The unique slug identifier of the tool

        Returns:
            Dictionary containing detailed tool information including:
            - slug, name, description
            - toolkit info (slug, name, logo)
            - input_parameters, output_parameters
            - no_auth, available_versions, version
            - scopes, tags, deprecated info

        Raises:
            NodeException: If tool not found (404), unauthorized (401), or other API errors
        """
        endpoint = f"/tools/{tool_slug}"

        try:
            response = self._make_request(endpoint, method="GET")
        except Exception as e:
            # Classify by the real HTTP status, not by searching the message for "404"/"401".
            # The message embeds the request URL (and so the slug), so a slug containing those
            # digits would otherwise be misreported.
            status_code = self._http_status_code(e)
            if status_code == 404:
                raise NodeException(f"Tool with slug '{tool_slug}' not found in Composio") from e
            if status_code == 401:
                raise NodeException(
                    f"Unauthorized access to tool '{tool_slug}'. Check your Composio API key."
                ) from e
            raise NodeException(f"Failed to retrieve tool details for '{tool_slug}': {str(e)}") from e

        logger.info("Retrieved tool details for slug '%s': %s", tool_slug, response)
        return response

    def execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Execute a tool using direct API request

        Args:
            tool_name: The name of the tool to execute (e.g., "HACKERNEWS_GET_USER")
            arguments: Dictionary of arguments to pass to the tool

        Returns:
            The response's "data" field if present, otherwise the whole decoded response.

        Raises:
            NodeException: If the API request fails.
        """
        endpoint = f"/tools/execute/{tool_name}"
        response = self._make_request(endpoint, method="POST", json_data={"arguments": arguments})
        return response.get("data", response)
|
@@ -5,6 +5,7 @@ from requests.exceptions import JSONDecodeError
|
|
5
5
|
|
6
6
|
from vellum.client import ApiError
|
7
7
|
from vellum.client.core.request_options import RequestOptions
|
8
|
+
from vellum.client.types.method_enum import MethodEnum
|
8
9
|
from vellum.client.types.vellum_secret import VellumSecret as ClientVellumSecret
|
9
10
|
from vellum.workflows.constants import APIRequestMethod
|
10
11
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
@@ -20,7 +21,7 @@ class BaseAPINode(BaseNode, Generic[StateType]):
|
|
20
21
|
Used to execute an API call.
|
21
22
|
|
22
23
|
url: str - The URL to send the request to.
|
23
|
-
method: APIRequestMethod - The HTTP method to use for the request.
|
24
|
+
method: Union[APIRequestMethod, str] - The HTTP method to use for the request.
|
24
25
|
data: Optional[str] - The data to send in the request body.
|
25
26
|
json: Optional["JsonObject"] - The JSON data to send in the request body.
|
26
27
|
headers: Optional[Dict[str, Union[str, VellumSecret]]] - The headers to send in the request.
|
@@ -31,7 +32,7 @@ class BaseAPINode(BaseNode, Generic[StateType]):
|
|
31
32
|
merge_behavior = MergeBehavior.AWAIT_ANY
|
32
33
|
|
33
34
|
url: str = ""
|
34
|
-
method: Optional[APIRequestMethod] = APIRequestMethod.GET
|
35
|
+
method: Optional[Union[APIRequestMethod, MethodEnum]] = APIRequestMethod.GET
|
35
36
|
data: Optional[str] = None
|
36
37
|
json: Optional[Json] = None
|
37
38
|
headers: Optional[Dict[str, Union[str, VellumSecret]]] = None
|
@@ -43,6 +44,21 @@ class BaseAPINode(BaseNode, Generic[StateType]):
|
|
43
44
|
status_code: int
|
44
45
|
text: str
|
45
46
|
|
47
|
+
def _normalize_http_method(self, method: Union[APIRequestMethod, MethodEnum]) -> str:
    """
    Convert *method* into the upper-case HTTP verb string used for the request.

    APIRequestMethod members map to their ``.value``. Plain strings are upper-cased and
    must then equal one of the APIRequestMethod values.

    Raises:
        NodeException: with code INVALID_INPUTS if the string is not a known method, or if
            *method* is neither an APIRequestMethod nor a str.
    """
    if isinstance(method, APIRequestMethod):
        return method.value

    if not isinstance(method, str):
        raise NodeException(
            f"Method must be either APIRequestMethod enum or string, got {type(method)}",
            code=WorkflowErrorCode.INVALID_INPUTS,
        )

    upper_method = method.upper()
    allowed = {member.value for member in APIRequestMethod}
    if upper_method not in allowed:
        raise NodeException(f"Invalid HTTP method '{method}'", code=WorkflowErrorCode.INVALID_INPUTS)
    return upper_method
|
61
|
+
|
46
62
|
def _validate(self) -> None:
|
47
63
|
if not self.url or not isinstance(self.url, str) or not self.url.strip():
|
48
64
|
raise NodeException("URL is required and must be a non-empty string", code=WorkflowErrorCode.INVALID_INPUTS)
|
@@ -55,7 +71,7 @@ class BaseAPINode(BaseNode, Generic[StateType]):
|
|
55
71
|
def _run(
|
56
72
|
self,
|
57
73
|
url: str,
|
58
|
-
method: Optional[APIRequestMethod] = APIRequestMethod.GET,
|
74
|
+
method: Optional[Union[APIRequestMethod, MethodEnum]] = APIRequestMethod.GET,
|
59
75
|
data: Optional[Union[str, Any]] = None,
|
60
76
|
json: Any = None,
|
61
77
|
headers: Any = None,
|
@@ -64,23 +80,25 @@ class BaseAPINode(BaseNode, Generic[StateType]):
|
|
64
80
|
) -> Outputs:
|
65
81
|
self._validate()
|
66
82
|
|
83
|
+
normalized_method = self._normalize_http_method(method) if method is not None else APIRequestMethod.GET.value
|
84
|
+
|
67
85
|
vellum_instance = False
|
68
86
|
for header in headers or {}:
|
69
87
|
if isinstance(headers[header], VellumSecret):
|
70
88
|
vellum_instance = True
|
71
89
|
if vellum_instance or bearer_token:
|
72
|
-
return self._vellum_execute_api(bearer_token, json, headers,
|
90
|
+
return self._vellum_execute_api(bearer_token, json, headers, normalized_method, url, timeout)
|
73
91
|
else:
|
74
|
-
return self._local_execute_api(data, headers, json,
|
92
|
+
return self._local_execute_api(data, headers, json, normalized_method, url, timeout)
|
75
93
|
|
76
94
|
def _local_execute_api(self, data, headers, json, method, url, timeout):
|
77
95
|
try:
|
78
96
|
if data is not None:
|
79
|
-
prepped = Request(method=method
|
97
|
+
prepped = Request(method=method, url=url, data=data, headers=headers).prepare()
|
80
98
|
elif json is not None:
|
81
|
-
prepped = Request(method=method
|
99
|
+
prepped = Request(method=method, url=url, json=json, headers=headers).prepare()
|
82
100
|
else:
|
83
|
-
prepped = Request(method=method
|
101
|
+
prepped = Request(method=method, url=url, headers=headers).prepare()
|
84
102
|
except Exception as e:
|
85
103
|
raise NodeException(f"Failed to prepare HTTP request: {e}", code=WorkflowErrorCode.PROVIDER_ERROR)
|
86
104
|
try:
|
@@ -110,7 +128,7 @@ class BaseAPINode(BaseNode, Generic[StateType]):
|
|
110
128
|
try:
|
111
129
|
vellum_response = self._context.vellum_client.execute_api(
|
112
130
|
url=url,
|
113
|
-
method=method
|
131
|
+
method=method,
|
114
132
|
body=data,
|
115
133
|
headers=headers,
|
116
134
|
bearer_token=client_vellum_secret,
|
File without changes
|
@@ -0,0 +1,47 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
from vellum.client.types.execute_api_response import ExecuteApiResponse
|
4
|
+
from vellum.workflows.constants import APIRequestMethod
|
5
|
+
from vellum.workflows.errors.types import WorkflowErrorCode
|
6
|
+
from vellum.workflows.exceptions import NodeException
|
7
|
+
from vellum.workflows.nodes.displayable.bases.api_node.node import BaseAPINode
|
8
|
+
from vellum.workflows.types.core import VellumSecret
|
9
|
+
|
10
|
+
|
11
|
+
@pytest.mark.parametrize("method_value", ["GET", "get", APIRequestMethod.GET])
def test_api_node_with_string_method(method_value, vellum_client):
    """Upper-case, lower-case and enum spellings of GET all reach execute_api as "GET"."""

    # GIVEN an API node using a secret header, so the request goes through the Vellum client
    class TestAPINode(BaseAPINode):
        method = method_value
        url = "https://example.com"
        headers = {"Authorization": VellumSecret(name="API_KEY")}

    # AND the Vellum client answers with a successful JSON response
    vellum_client.execute_api.return_value = ExecuteApiResponse(
        json_={"status": "success"},
        headers={"content-type": "application/json"},
        status_code=200,
        text='{"status": "success"}',
    )

    # WHEN the node runs
    outputs = TestAPINode().run()

    # THEN the call succeeds
    assert outputs.status_code == 200

    # AND execute_api was invoked exactly once with the normalized verb
    vellum_client.execute_api.assert_called_once()
    _, call_kwargs = vellum_client.execute_api.call_args
    assert call_kwargs["method"] == "GET"
|
34
|
+
|
35
|
+
|
36
|
+
def test_api_node_with_invalid_method():
    """A method string that is not a known HTTP verb is rejected with INVALID_INPUTS."""

    # GIVEN an API node configured with an unknown method
    class TestAPINode(BaseAPINode):
        method = "INVALID_METHOD"
        url = "https://example.com"

    node = TestAPINode()

    # WHEN the node runs, THEN a NodeException is raised
    with pytest.raises(NodeException) as exc_info:
        node.run()

    # AND it is classified as bad input and names the offending method
    raised = exc_info.value
    assert raised.code == WorkflowErrorCode.INVALID_INPUTS
    assert str(raised) == "Invalid HTTP method 'INVALID_METHOD'"
|