vellum-ai: vellum_ai-1.0.4-py3-none-any.whl → vellum_ai-1.0.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. vellum/__init__.py +0 -6
  2. vellum/client/core/client_wrapper.py +2 -2
  3. vellum/client/types/__init__.py +0 -6
  4. vellum/client/types/organization_read.py +1 -2
  5. vellum/workflows/events/context.py +111 -0
  6. vellum/workflows/integrations/__init__.py +0 -0
  7. vellum/workflows/integrations/composio_service.py +138 -0
  8. vellum/workflows/nodes/displayable/bases/api_node/node.py +27 -9
  9. vellum/workflows/nodes/displayable/bases/api_node/tests/__init__.py +0 -0
  10. vellum/workflows/nodes/displayable/bases/api_node/tests/test_node.py +47 -0
  11. vellum/workflows/nodes/displayable/tool_calling_node/tests/test_composio_service.py +63 -58
  12. vellum/workflows/nodes/displayable/tool_calling_node/tests/test_utils.py +21 -1
  13. vellum/workflows/nodes/displayable/tool_calling_node/utils.py +124 -59
  14. vellum/workflows/types/definition.py +4 -2
  15. vellum/workflows/utils/functions.py +13 -1
  16. vellum/workflows/utils/tests/test_functions.py +32 -1
  17. {vellum_ai-1.0.4.dist-info → vellum_ai-1.0.6.dist-info}/METADATA +1 -3
  18. {vellum_ai-1.0.4.dist-info → vellum_ai-1.0.6.dist-info}/RECORD +26 -27
  19. vellum_cli/push.py +11 -2
  20. vellum_cli/tests/test_push.py +57 -1
  21. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +2 -0
  22. vellum_ee/workflows/display/nodes/vellum/tests/test_code_execution_node.py +16 -0
  23. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_composio_serialization.py +89 -0
  24. vellum/client/types/organization_limit_config.py +0 -25
  25. vellum/client/types/quota.py +0 -21
  26. vellum/client/types/vembda_service_tier_enum.py +0 -5
  27. vellum/types/organization_limit_config.py +0 -3
  28. vellum/types/quota.py +0 -3
  29. vellum/types/vembda_service_tier_enum.py +0 -3
  30. vellum/workflows/nodes/displayable/tool_calling_node/composio_service.py +0 -83
  31. {vellum_ai-1.0.4.dist-info → vellum_ai-1.0.6.dist-info}/LICENSE +0 -0
  32. {vellum_ai-1.0.4.dist-info → vellum_ai-1.0.6.dist-info}/WHEEL +0 -0
  33. {vellum_ai-1.0.4.dist-info → vellum_ai-1.0.6.dist-info}/entry_points.txt +0 -0
vellum/__init__.py CHANGED
@@ -288,7 +288,6 @@ from .client.types import (
288
288
  OpenAiVectorizerTextEmbedding3SmallRequest,
289
289
  OpenAiVectorizerTextEmbeddingAda002,
290
290
  OpenAiVectorizerTextEmbeddingAda002Request,
291
- OrganizationLimitConfig,
292
291
  OrganizationRead,
293
292
  PaginatedContainerImageReadList,
294
293
  PaginatedDeploymentReleaseTagReadList,
@@ -329,7 +328,6 @@ from .client.types import (
329
328
  PromptRequestStringInput,
330
329
  PromptSettings,
331
330
  PromptVersionBuildConfigSandbox,
332
- Quota,
333
331
  RawPromptExecutionOverridesRequest,
334
332
  ReductoChunkerConfig,
335
333
  ReductoChunkerConfigRequest,
@@ -527,7 +525,6 @@ from .client.types import (
527
525
  VellumVariableExtensions,
528
526
  VellumVariableType,
529
527
  VellumWorkflowExecutionEvent,
530
- VembdaServiceTierEnum,
531
528
  WorkflowDeploymentEventExecutionsResponse,
532
529
  WorkflowDeploymentHistoryItem,
533
530
  WorkflowDeploymentParentContext,
@@ -939,7 +936,6 @@ __all__ = [
939
936
  "OpenAiVectorizerTextEmbedding3SmallRequest",
940
937
  "OpenAiVectorizerTextEmbeddingAda002",
941
938
  "OpenAiVectorizerTextEmbeddingAda002Request",
942
- "OrganizationLimitConfig",
943
939
  "OrganizationRead",
944
940
  "PaginatedContainerImageReadList",
945
941
  "PaginatedDeploymentReleaseTagReadList",
@@ -980,7 +976,6 @@ __all__ = [
980
976
  "PromptRequestStringInput",
981
977
  "PromptSettings",
982
978
  "PromptVersionBuildConfigSandbox",
983
- "Quota",
984
979
  "RawPromptExecutionOverridesRequest",
985
980
  "ReductoChunkerConfig",
986
981
  "ReductoChunkerConfigRequest",
@@ -1180,7 +1175,6 @@ __all__ = [
1180
1175
  "VellumVariableExtensions",
1181
1176
  "VellumVariableType",
1182
1177
  "VellumWorkflowExecutionEvent",
1183
- "VembdaServiceTierEnum",
1184
1178
  "WorkflowDeploymentEventExecutionsResponse",
1185
1179
  "WorkflowDeploymentHistoryItem",
1186
1180
  "WorkflowDeploymentParentContext",
vellum/client/core/client_wrapper.py CHANGED
@@ -25,10 +25,10 @@ class BaseClientWrapper:
25
25
 
26
26
  def get_headers(self) -> typing.Dict[str, str]:
27
27
  headers: typing.Dict[str, str] = {
28
- "User-Agent": "vellum-ai/1.0.4",
28
+ "User-Agent": "vellum-ai/1.0.6",
29
29
  "X-Fern-Language": "Python",
30
30
  "X-Fern-SDK-Name": "vellum-ai",
31
- "X-Fern-SDK-Version": "1.0.4",
31
+ "X-Fern-SDK-Version": "1.0.6",
32
32
  }
33
33
  if self._api_version is not None:
34
34
  headers["X-API-Version"] = self._api_version
vellum/client/types/__init__.py CHANGED
@@ -296,7 +296,6 @@ from .open_ai_vectorizer_text_embedding_3_small import OpenAiVectorizerTextEmbed
296
296
  from .open_ai_vectorizer_text_embedding_3_small_request import OpenAiVectorizerTextEmbedding3SmallRequest
297
297
  from .open_ai_vectorizer_text_embedding_ada_002 import OpenAiVectorizerTextEmbeddingAda002
298
298
  from .open_ai_vectorizer_text_embedding_ada_002_request import OpenAiVectorizerTextEmbeddingAda002Request
299
- from .organization_limit_config import OrganizationLimitConfig
300
299
  from .organization_read import OrganizationRead
301
300
  from .paginated_container_image_read_list import PaginatedContainerImageReadList
302
301
  from .paginated_deployment_release_tag_read_list import PaginatedDeploymentReleaseTagReadList
@@ -337,7 +336,6 @@ from .prompt_request_json_input import PromptRequestJsonInput
337
336
  from .prompt_request_string_input import PromptRequestStringInput
338
337
  from .prompt_settings import PromptSettings
339
338
  from .prompt_version_build_config_sandbox import PromptVersionBuildConfigSandbox
340
- from .quota import Quota
341
339
  from .raw_prompt_execution_overrides_request import RawPromptExecutionOverridesRequest
342
340
  from .reducto_chunker_config import ReductoChunkerConfig
343
341
  from .reducto_chunker_config_request import ReductoChunkerConfigRequest
@@ -551,7 +549,6 @@ from .vellum_variable import VellumVariable
551
549
  from .vellum_variable_extensions import VellumVariableExtensions
552
550
  from .vellum_variable_type import VellumVariableType
553
551
  from .vellum_workflow_execution_event import VellumWorkflowExecutionEvent
554
- from .vembda_service_tier_enum import VembdaServiceTierEnum
555
552
  from .workflow_deployment_event_executions_response import WorkflowDeploymentEventExecutionsResponse
556
553
  from .workflow_deployment_history_item import WorkflowDeploymentHistoryItem
557
554
  from .workflow_deployment_parent_context import WorkflowDeploymentParentContext
@@ -918,7 +915,6 @@ __all__ = [
918
915
  "OpenAiVectorizerTextEmbedding3SmallRequest",
919
916
  "OpenAiVectorizerTextEmbeddingAda002",
920
917
  "OpenAiVectorizerTextEmbeddingAda002Request",
921
- "OrganizationLimitConfig",
922
918
  "OrganizationRead",
923
919
  "PaginatedContainerImageReadList",
924
920
  "PaginatedDeploymentReleaseTagReadList",
@@ -959,7 +955,6 @@ __all__ = [
959
955
  "PromptRequestStringInput",
960
956
  "PromptSettings",
961
957
  "PromptVersionBuildConfigSandbox",
962
- "Quota",
963
958
  "RawPromptExecutionOverridesRequest",
964
959
  "ReductoChunkerConfig",
965
960
  "ReductoChunkerConfigRequest",
@@ -1157,7 +1152,6 @@ __all__ = [
1157
1152
  "VellumVariableExtensions",
1158
1153
  "VellumVariableType",
1159
1154
  "VellumWorkflowExecutionEvent",
1160
- "VembdaServiceTierEnum",
1161
1155
  "WorkflowDeploymentEventExecutionsResponse",
1162
1156
  "WorkflowDeploymentHistoryItem",
1163
1157
  "WorkflowDeploymentParentContext",
vellum/client/types/organization_read.py CHANGED
@@ -3,7 +3,6 @@
3
3
  from ..core.pydantic_utilities import UniversalBaseModel
4
4
  import typing
5
5
  from .new_member_join_behavior_enum import NewMemberJoinBehaviorEnum
6
- from .organization_limit_config import OrganizationLimitConfig
7
6
  from ..core.pydantic_utilities import IS_PYDANTIC_V2
8
7
  import pydantic
9
8
 
@@ -13,7 +12,7 @@ class OrganizationRead(UniversalBaseModel):
13
12
  name: str
14
13
  allow_staff_access: typing.Optional[bool] = None
15
14
  new_member_join_behavior: NewMemberJoinBehaviorEnum
16
- limit_config: OrganizationLimitConfig
15
+ limit_config: typing.Optional[typing.Dict[str, typing.Optional[typing.Any]]] = None
17
16
 
18
17
  if IS_PYDANTIC_V2:
19
18
  model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
vellum/workflows/events/context.py ADDED
@@ -0,0 +1,111 @@
1
+ """Monitoring execution context for workflow tracing."""
2
+
3
+ import threading
4
+ from uuid import UUID
5
+ from typing import Dict, Optional
6
+
7
+ from vellum.workflows.context import ExecutionContext
8
+
9
+ DEFAULT_TRACE_ID = UUID("00000000-0000-0000-0000-000000000000")
10
+
11
+ # Thread-local storage for monitoring execution context
12
+ _monitoring_execution_context: threading.local = threading.local()
13
+ # Thread-local storage for current span_id
14
+ _current_span_id: threading.local = threading.local()
15
+
16
+
17
+ class _MonitoringContextStore:
18
+ """
19
+ thread-safe storage for monitoring contexts.
20
+ handles context persistence and retrieval across threads.
21
+ relies on the execution context manager for manual retrieval
22
+ """
23
+
24
+ def __init__(self):
25
+ self._lock = threading.Lock()
26
+ self._contexts: Dict[str, ExecutionContext] = {}
27
+ self._thread_contexts: Dict[int, ExecutionContext] = {}
28
+ self._current_trace_id: Optional[UUID] = None
29
+
30
+ def set_current_trace_id(self, trace_id: UUID) -> None:
31
+ """Set the current active trace_id that should be used by all threads."""
32
+ if trace_id != DEFAULT_TRACE_ID:
33
+ with self._lock:
34
+ self._current_trace_id = trace_id
35
+
36
+ def get_current_trace_id(self) -> Optional[UUID]:
37
+ """Get the current active trace_id that should be used by all threads."""
38
+ with self._lock:
39
+ return self._current_trace_id
40
+
41
+ def set_current_span_id(self, span_id: UUID) -> None:
42
+ """Set the current active span_id for this thread."""
43
+ _current_span_id.span_id = span_id
44
+
45
+ def get_current_span_id(self) -> Optional[UUID]:
46
+ """Get the current active span_id for this thread."""
47
+ return getattr(_current_span_id, "span_id", None)
48
+
49
+ def store_context(self, context: Optional[ExecutionContext]) -> None:
50
+ """Store monitoring parent context using multiple keys for reliable retrieval."""
51
+ if not context or context.parent_context is None:
52
+ return
53
+
54
+ thread_id = threading.get_ident()
55
+ trace_id = self.get_current_trace_id()
56
+ if context.trace_id != DEFAULT_TRACE_ID and trace_id is None:
57
+ self.set_current_trace_id(context.trace_id)
58
+
59
+ with self._lock:
60
+ # Use trace:span:thread for unique context storage
61
+ trace_span_thread_key = (
62
+ f"trace:{str(trace_id)}:span:{str(context.parent_context.span_id)}:thread:{thread_id}"
63
+ )
64
+ self._contexts[trace_span_thread_key] = context
65
+
66
+ def retrieve_context(self, trace_id: UUID, span_id: Optional[UUID] = None) -> Optional[ExecutionContext]:
67
+ """Retrieve monitoring parent context with multiple fallback strategies."""
68
+ thread_id = threading.get_ident()
69
+ with self._lock:
70
+ if not span_id:
71
+ span_id = getattr(_current_span_id, "span_id", None)
72
+ if not span_id:
73
+ return None
74
+
75
+ span_key = f"trace:{str(trace_id)}:span:{str(span_id)}:thread:{thread_id}"
76
+ if span_key in self._contexts:
77
+ result = self._contexts[span_key]
78
+ return result
79
+
80
+ return None
81
+
82
+
83
+ # Global instance for cross-boundary context persistence
84
+ _monitoring_context_store = _MonitoringContextStore()
85
+
86
+
87
+ def get_monitoring_execution_context() -> ExecutionContext:
88
+ """Get the current monitoring execution context, with intelligent fallback."""
89
+ if hasattr(_monitoring_execution_context, "context"):
90
+ context = _monitoring_execution_context.context
91
+ if context.trace_id != DEFAULT_TRACE_ID and context.parent_context:
92
+ return context
93
+
94
+ # If no thread-local context, try to restore from global store using current trace_id
95
+ trace_id = _monitoring_context_store.get_current_trace_id()
96
+ span_id = _current_span_id.span_id if hasattr(_current_span_id, "span_id") else None
97
+ if trace_id:
98
+ if trace_id != DEFAULT_TRACE_ID:
99
+ context = _monitoring_context_store.retrieve_context(trace_id, span_id)
100
+ if context:
101
+ _monitoring_execution_context.context = context
102
+ return context
103
+ return ExecutionContext()
104
+
105
+
106
+ def set_monitoring_execution_context(context: ExecutionContext) -> None:
107
+ """Set the current monitoring execution context and persist it for cross-boundary access."""
108
+ _monitoring_execution_context.context = context
109
+
110
+ if context.trace_id and context.parent_context:
111
+ _monitoring_context_store.store_context(context)
vellum/workflows/integrations/__init__.py — file without changes (empty file, +0 -0)
vellum/workflows/integrations/composio_service.py ADDED
@@ -0,0 +1,138 @@
1
+ from dataclasses import dataclass
2
+ import logging
3
+ import os
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import requests
7
+
8
+ from vellum.workflows.exceptions import NodeException
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ @dataclass
14
+ class ConnectionInfo:
15
+ """Information about a user's authorized connection"""
16
+
17
+ connection_id: str
18
+ integration_name: str
19
+ created_at: str
20
+ updated_at: str
21
+ status: str = "ACTIVE"
22
+
23
+
24
+ class ComposioService:
25
+ """Composio API client for managing connections and executing tools"""
26
+
27
+ def __init__(self, api_key: Optional[str] = None):
28
+ # If no API key provided, look it up from environment variables
29
+ if api_key is None:
30
+ api_key = self._get_api_key_from_env()
31
+
32
+ if not api_key:
33
+ common_env_var_names = ["COMPOSIO_API_KEY", "COMPOSIO_KEY"]
34
+ raise NodeException(
35
+ "No Composio API key found. "
36
+ "Please provide an api_key parameter or set one of these environment variables: "
37
+ + ", ".join(common_env_var_names)
38
+ )
39
+
40
+ self.api_key = api_key
41
+ self.base_url = "https://backend.composio.dev/api/v3"
42
+
43
+ @staticmethod
44
+ def _get_api_key_from_env() -> Optional[str]:
45
+ """Get Composio API key from environment variables"""
46
+ common_env_var_names = ["COMPOSIO_API_KEY", "COMPOSIO_KEY"]
47
+
48
+ for env_var_name in common_env_var_names:
49
+ value = os.environ.get(env_var_name)
50
+ if value:
51
+ return value
52
+ return None
53
+
54
+ def _make_request(
55
+ self, endpoint: str, method: str = "GET", params: Optional[dict] = None, json_data: Optional[dict] = None
56
+ ) -> dict:
57
+ """Make a request to the Composio API"""
58
+ headers = {
59
+ "x-api-key": self.api_key,
60
+ "Content-Type": "application/json",
61
+ }
62
+
63
+ url = f"{self.base_url}{endpoint}"
64
+
65
+ try:
66
+ if method == "GET":
67
+ response = requests.get(url, headers=headers, params=params or {}, timeout=30)
68
+ elif method == "POST":
69
+ response = requests.post(url, headers=headers, json=json_data or {}, timeout=30)
70
+ else:
71
+ raise ValueError(f"Unsupported HTTP method: {method}")
72
+
73
+ response.raise_for_status()
74
+ return response.json()
75
+ except Exception as e:
76
+ raise NodeException(f"Composio API request failed: {e}")
77
+
78
+ def get_user_connections(self) -> List[ConnectionInfo]:
79
+ """Get all authorized connections for the user"""
80
+ response = self._make_request("/connected_accounts")
81
+
82
+ return [
83
+ ConnectionInfo(
84
+ connection_id=item.get("id"),
85
+ integration_name=item.get("toolkit", {}).get("slug", ""),
86
+ status=item.get("status", "ACTIVE"),
87
+ created_at=item.get("created_at", ""),
88
+ updated_at=item.get("updated_at", ""),
89
+ )
90
+ for item in response.get("items", [])
91
+ ]
92
+
93
+ def get_tool_by_slug(self, tool_slug: str) -> Dict[str, Any]:
94
+ """Get detailed information about a tool using its slug identifier
95
+
96
+ Args:
97
+ tool_slug: The unique slug identifier of the tool
98
+
99
+ Returns:
100
+ Dictionary containing detailed tool information including:
101
+ - slug, name, description
102
+ - toolkit info (slug, name, logo)
103
+ - input_parameters, output_parameters
104
+ - no_auth, available_versions, version
105
+ - scopes, tags, deprecated info
106
+
107
+ Raises:
108
+ NodeException: If tool not found (404), unauthorized (401), or other API errors
109
+ """
110
+ endpoint = f"/tools/{tool_slug}"
111
+
112
+ try:
113
+ response = self._make_request(endpoint, method="GET")
114
+ logger.info(f"Retrieved tool details for slug '{tool_slug}': {response}")
115
+ return response
116
+ except Exception as e:
117
+ # Enhanced error handling for specific cases
118
+ error_message = str(e)
119
+ if "404" in error_message:
120
+ raise NodeException(f"Tool with slug '{tool_slug}' not found in Composio")
121
+ elif "401" in error_message:
122
+ raise NodeException(f"Unauthorized access to tool '{tool_slug}'. Check your Composio API key.")
123
+ else:
124
+ raise NodeException(f"Failed to retrieve tool details for '{tool_slug}': {error_message}")
125
+
126
+ def execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
127
+ """Execute a tool using direct API request
128
+
129
+ Args:
130
+ tool_name: The name of the tool to execute (e.g., "HACKERNEWS_GET_USER")
131
+ arguments: Dictionary of arguments to pass to the tool
132
+
133
+ Returns:
134
+ The result of the tool execution
135
+ """
136
+ endpoint = f"/tools/execute/{tool_name}"
137
+ response = self._make_request(endpoint, method="POST", json_data={"arguments": arguments})
138
+ return response.get("data", response)
vellum/workflows/nodes/displayable/bases/api_node/node.py CHANGED
@@ -5,6 +5,7 @@ from requests.exceptions import JSONDecodeError
5
5
 
6
6
  from vellum.client import ApiError
7
7
  from vellum.client.core.request_options import RequestOptions
8
+ from vellum.client.types.method_enum import MethodEnum
8
9
  from vellum.client.types.vellum_secret import VellumSecret as ClientVellumSecret
9
10
  from vellum.workflows.constants import APIRequestMethod
10
11
  from vellum.workflows.errors.types import WorkflowErrorCode
@@ -20,7 +21,7 @@ class BaseAPINode(BaseNode, Generic[StateType]):
20
21
  Used to execute an API call.
21
22
 
22
23
  url: str - The URL to send the request to.
23
- method: APIRequestMethod - The HTTP method to use for the request.
24
+ method: Union[APIRequestMethod, str] - The HTTP method to use for the request.
24
25
  data: Optional[str] - The data to send in the request body.
25
26
  json: Optional["JsonObject"] - The JSON data to send in the request body.
26
27
  headers: Optional[Dict[str, Union[str, VellumSecret]]] - The headers to send in the request.
@@ -31,7 +32,7 @@ class BaseAPINode(BaseNode, Generic[StateType]):
31
32
  merge_behavior = MergeBehavior.AWAIT_ANY
32
33
 
33
34
  url: str = ""
34
- method: Optional[APIRequestMethod] = APIRequestMethod.GET
35
+ method: Optional[Union[APIRequestMethod, MethodEnum]] = APIRequestMethod.GET
35
36
  data: Optional[str] = None
36
37
  json: Optional[Json] = None
37
38
  headers: Optional[Dict[str, Union[str, VellumSecret]]] = None
@@ -43,6 +44,21 @@ class BaseAPINode(BaseNode, Generic[StateType]):
43
44
  status_code: int
44
45
  text: str
45
46
 
47
+ def _normalize_http_method(self, method: Union[APIRequestMethod, MethodEnum]) -> str:
48
+ if isinstance(method, APIRequestMethod):
49
+ method_str = method.value
50
+ elif isinstance(method, str):
51
+ method_str = method.upper()
52
+ valid_methods = {m.value for m in APIRequestMethod}
53
+ if method_str not in valid_methods:
54
+ raise NodeException(f"Invalid HTTP method '{method}'", code=WorkflowErrorCode.INVALID_INPUTS)
55
+ else:
56
+ raise NodeException(
57
+ f"Method must be either APIRequestMethod enum or string, got {type(method)}",
58
+ code=WorkflowErrorCode.INVALID_INPUTS,
59
+ )
60
+ return method_str
61
+
46
62
  def _validate(self) -> None:
47
63
  if not self.url or not isinstance(self.url, str) or not self.url.strip():
48
64
  raise NodeException("URL is required and must be a non-empty string", code=WorkflowErrorCode.INVALID_INPUTS)
@@ -55,7 +71,7 @@ class BaseAPINode(BaseNode, Generic[StateType]):
55
71
  def _run(
56
72
  self,
57
73
  url: str,
58
- method: Optional[APIRequestMethod] = APIRequestMethod.GET,
74
+ method: Optional[Union[APIRequestMethod, MethodEnum]] = APIRequestMethod.GET,
59
75
  data: Optional[Union[str, Any]] = None,
60
76
  json: Any = None,
61
77
  headers: Any = None,
@@ -64,23 +80,25 @@ class BaseAPINode(BaseNode, Generic[StateType]):
64
80
  ) -> Outputs:
65
81
  self._validate()
66
82
 
83
+ normalized_method = self._normalize_http_method(method) if method is not None else APIRequestMethod.GET.value
84
+
67
85
  vellum_instance = False
68
86
  for header in headers or {}:
69
87
  if isinstance(headers[header], VellumSecret):
70
88
  vellum_instance = True
71
89
  if vellum_instance or bearer_token:
72
- return self._vellum_execute_api(bearer_token, json, headers, method, url, timeout)
90
+ return self._vellum_execute_api(bearer_token, json, headers, normalized_method, url, timeout)
73
91
  else:
74
- return self._local_execute_api(data, headers, json, method, url, timeout)
92
+ return self._local_execute_api(data, headers, json, normalized_method, url, timeout)
75
93
 
76
94
  def _local_execute_api(self, data, headers, json, method, url, timeout):
77
95
  try:
78
96
  if data is not None:
79
- prepped = Request(method=method.value, url=url, data=data, headers=headers).prepare()
97
+ prepped = Request(method=method, url=url, data=data, headers=headers).prepare()
80
98
  elif json is not None:
81
- prepped = Request(method=method.value, url=url, json=json, headers=headers).prepare()
99
+ prepped = Request(method=method, url=url, json=json, headers=headers).prepare()
82
100
  else:
83
- prepped = Request(method=method.value, url=url, headers=headers).prepare()
101
+ prepped = Request(method=method, url=url, headers=headers).prepare()
84
102
  except Exception as e:
85
103
  raise NodeException(f"Failed to prepare HTTP request: {e}", code=WorkflowErrorCode.PROVIDER_ERROR)
86
104
  try:
@@ -110,7 +128,7 @@ class BaseAPINode(BaseNode, Generic[StateType]):
110
128
  try:
111
129
  vellum_response = self._context.vellum_client.execute_api(
112
130
  url=url,
113
- method=method.value,
131
+ method=method,
114
132
  body=data,
115
133
  headers=headers,
116
134
  bearer_token=client_vellum_secret,
vellum/workflows/nodes/displayable/bases/api_node/tests/test_node.py ADDED
@@ -0,0 +1,47 @@
1
+ import pytest
2
+
3
+ from vellum.client.types.execute_api_response import ExecuteApiResponse
4
+ from vellum.workflows.constants import APIRequestMethod
5
+ from vellum.workflows.errors.types import WorkflowErrorCode
6
+ from vellum.workflows.exceptions import NodeException
7
+ from vellum.workflows.nodes.displayable.bases.api_node.node import BaseAPINode
8
+ from vellum.workflows.types.core import VellumSecret
9
+
10
+
11
+ @pytest.mark.parametrize("method_value", ["GET", "get", APIRequestMethod.GET])
12
+ def test_api_node_with_string_method(method_value, vellum_client):
13
+ class TestAPINode(BaseAPINode):
14
+ method = method_value
15
+ url = "https://example.com"
16
+ headers = {"Authorization": VellumSecret(name="API_KEY")}
17
+
18
+ mock_response = ExecuteApiResponse(
19
+ json_={"status": "success"},
20
+ headers={"content-type": "application/json"},
21
+ status_code=200,
22
+ text='{"status": "success"}',
23
+ )
24
+ vellum_client.execute_api.return_value = mock_response
25
+
26
+ node = TestAPINode()
27
+ result = node.run()
28
+
29
+ assert result.status_code == 200
30
+
31
+ vellum_client.execute_api.assert_called_once()
32
+ call_args = vellum_client.execute_api.call_args
33
+ assert call_args[1]["method"] == "GET"
34
+
35
+
36
+ def test_api_node_with_invalid_method():
37
+ class TestAPINode(BaseAPINode):
38
+ method = "INVALID_METHOD"
39
+ url = "https://example.com"
40
+
41
+ node = TestAPINode()
42
+
43
+ with pytest.raises(NodeException) as exc_info:
44
+ node.run()
45
+
46
+ assert exc_info.value.code == WorkflowErrorCode.INVALID_INPUTS
47
+ assert "Invalid HTTP method 'INVALID_METHOD'" == str(exc_info.value)