vellum-ai 1.0.5__py3-none-any.whl → 1.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. vellum/__init__.py +0 -8
  2. vellum/client/core/client_wrapper.py +2 -2
  3. vellum/client/types/__init__.py +0 -8
  4. vellum/client/types/organization_read.py +1 -2
  5. vellum/workflows/events/context.py +111 -0
  6. vellum/workflows/integrations/__init__.py +0 -0
  7. vellum/workflows/integrations/composio_service.py +138 -0
  8. vellum/workflows/nodes/displayable/api_node/tests/test_api_node.py +8 -2
  9. vellum/workflows/nodes/displayable/bases/api_node/node.py +36 -9
  10. vellum/workflows/nodes/displayable/bases/api_node/tests/__init__.py +0 -0
  11. vellum/workflows/nodes/displayable/bases/api_node/tests/test_node.py +124 -0
  12. vellum/workflows/nodes/displayable/tool_calling_node/node.py +2 -2
  13. vellum/workflows/nodes/displayable/tool_calling_node/tests/test_composio_service.py +63 -58
  14. vellum/workflows/nodes/displayable/tool_calling_node/tests/test_utils.py +147 -2
  15. vellum/workflows/nodes/displayable/tool_calling_node/utils.py +61 -41
  16. vellum/workflows/types/definition.py +4 -2
  17. vellum/workflows/utils/functions.py +29 -2
  18. vellum/workflows/utils/tests/test_functions.py +115 -1
  19. {vellum_ai-1.0.5.dist-info → vellum_ai-1.0.7.dist-info}/METADATA +1 -3
  20. {vellum_ai-1.0.5.dist-info → vellum_ai-1.0.7.dist-info}/RECORD +29 -33
  21. vellum_cli/push.py +11 -2
  22. vellum_cli/tests/test_push.py +57 -1
  23. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +2 -0
  24. vellum_ee/workflows/display/nodes/vellum/tests/test_code_execution_node.py +16 -0
  25. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_composio_serialization.py +3 -0
  26. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_serialization.py +8 -2
  27. vellum/client/types/name_enum.py +0 -7
  28. vellum/client/types/organization_limit_config.py +0 -25
  29. vellum/client/types/quota.py +0 -22
  30. vellum/client/types/vembda_service_tier_enum.py +0 -5
  31. vellum/types/name_enum.py +0 -3
  32. vellum/types/organization_limit_config.py +0 -3
  33. vellum/types/quota.py +0 -3
  34. vellum/types/vembda_service_tier_enum.py +0 -3
  35. vellum/workflows/nodes/displayable/tool_calling_node/composio_service.py +0 -83
  36. {vellum_ai-1.0.5.dist-info → vellum_ai-1.0.7.dist-info}/LICENSE +0 -0
  37. {vellum_ai-1.0.5.dist-info → vellum_ai-1.0.7.dist-info}/WHEEL +0 -0
  38. {vellum_ai-1.0.5.dist-info → vellum_ai-1.0.7.dist-info}/entry_points.txt +0 -0
vellum/__init__.py CHANGED
@@ -217,7 +217,6 @@ from .client.types import (
217
217
  MlModelRead,
218
218
  MlModelUsage,
219
219
  MlModelUsageWrapper,
220
- NameEnum,
221
220
  NamedScenarioInputChatHistoryVariableValueRequest,
222
221
  NamedScenarioInputJsonVariableValueRequest,
223
222
  NamedScenarioInputRequest,
@@ -289,7 +288,6 @@ from .client.types import (
289
288
  OpenAiVectorizerTextEmbedding3SmallRequest,
290
289
  OpenAiVectorizerTextEmbeddingAda002,
291
290
  OpenAiVectorizerTextEmbeddingAda002Request,
292
- OrganizationLimitConfig,
293
291
  OrganizationRead,
294
292
  PaginatedContainerImageReadList,
295
293
  PaginatedDeploymentReleaseTagReadList,
@@ -330,7 +328,6 @@ from .client.types import (
330
328
  PromptRequestStringInput,
331
329
  PromptSettings,
332
330
  PromptVersionBuildConfigSandbox,
333
- Quota,
334
331
  RawPromptExecutionOverridesRequest,
335
332
  ReductoChunkerConfig,
336
333
  ReductoChunkerConfigRequest,
@@ -528,7 +525,6 @@ from .client.types import (
528
525
  VellumVariableExtensions,
529
526
  VellumVariableType,
530
527
  VellumWorkflowExecutionEvent,
531
- VembdaServiceTierEnum,
532
528
  WorkflowDeploymentEventExecutionsResponse,
533
529
  WorkflowDeploymentHistoryItem,
534
530
  WorkflowDeploymentParentContext,
@@ -868,7 +864,6 @@ __all__ = [
868
864
  "MlModelRead",
869
865
  "MlModelUsage",
870
866
  "MlModelUsageWrapper",
871
- "NameEnum",
872
867
  "NamedScenarioInputChatHistoryVariableValueRequest",
873
868
  "NamedScenarioInputJsonVariableValueRequest",
874
869
  "NamedScenarioInputRequest",
@@ -941,7 +936,6 @@ __all__ = [
941
936
  "OpenAiVectorizerTextEmbedding3SmallRequest",
942
937
  "OpenAiVectorizerTextEmbeddingAda002",
943
938
  "OpenAiVectorizerTextEmbeddingAda002Request",
944
- "OrganizationLimitConfig",
945
939
  "OrganizationRead",
946
940
  "PaginatedContainerImageReadList",
947
941
  "PaginatedDeploymentReleaseTagReadList",
@@ -982,7 +976,6 @@ __all__ = [
982
976
  "PromptRequestStringInput",
983
977
  "PromptSettings",
984
978
  "PromptVersionBuildConfigSandbox",
985
- "Quota",
986
979
  "RawPromptExecutionOverridesRequest",
987
980
  "ReductoChunkerConfig",
988
981
  "ReductoChunkerConfigRequest",
@@ -1182,7 +1175,6 @@ __all__ = [
1182
1175
  "VellumVariableExtensions",
1183
1176
  "VellumVariableType",
1184
1177
  "VellumWorkflowExecutionEvent",
1185
- "VembdaServiceTierEnum",
1186
1178
  "WorkflowDeploymentEventExecutionsResponse",
1187
1179
  "WorkflowDeploymentHistoryItem",
1188
1180
  "WorkflowDeploymentParentContext",
@@ -25,10 +25,10 @@ class BaseClientWrapper:
25
25
 
26
26
  def get_headers(self) -> typing.Dict[str, str]:
27
27
  headers: typing.Dict[str, str] = {
28
- "User-Agent": "vellum-ai/1.0.5",
28
+ "User-Agent": "vellum-ai/1.0.7",
29
29
  "X-Fern-Language": "Python",
30
30
  "X-Fern-SDK-Name": "vellum-ai",
31
- "X-Fern-SDK-Version": "1.0.5",
31
+ "X-Fern-SDK-Version": "1.0.7",
32
32
  }
33
33
  if self._api_version is not None:
34
34
  headers["X-API-Version"] = self._api_version
@@ -225,7 +225,6 @@ from .metric_node_result import MetricNodeResult
225
225
  from .ml_model_read import MlModelRead
226
226
  from .ml_model_usage import MlModelUsage
227
227
  from .ml_model_usage_wrapper import MlModelUsageWrapper
228
- from .name_enum import NameEnum
229
228
  from .named_scenario_input_chat_history_variable_value_request import NamedScenarioInputChatHistoryVariableValueRequest
230
229
  from .named_scenario_input_json_variable_value_request import NamedScenarioInputJsonVariableValueRequest
231
230
  from .named_scenario_input_request import NamedScenarioInputRequest
@@ -297,7 +296,6 @@ from .open_ai_vectorizer_text_embedding_3_small import OpenAiVectorizerTextEmbed
297
296
  from .open_ai_vectorizer_text_embedding_3_small_request import OpenAiVectorizerTextEmbedding3SmallRequest
298
297
  from .open_ai_vectorizer_text_embedding_ada_002 import OpenAiVectorizerTextEmbeddingAda002
299
298
  from .open_ai_vectorizer_text_embedding_ada_002_request import OpenAiVectorizerTextEmbeddingAda002Request
300
- from .organization_limit_config import OrganizationLimitConfig
301
299
  from .organization_read import OrganizationRead
302
300
  from .paginated_container_image_read_list import PaginatedContainerImageReadList
303
301
  from .paginated_deployment_release_tag_read_list import PaginatedDeploymentReleaseTagReadList
@@ -338,7 +336,6 @@ from .prompt_request_json_input import PromptRequestJsonInput
338
336
  from .prompt_request_string_input import PromptRequestStringInput
339
337
  from .prompt_settings import PromptSettings
340
338
  from .prompt_version_build_config_sandbox import PromptVersionBuildConfigSandbox
341
- from .quota import Quota
342
339
  from .raw_prompt_execution_overrides_request import RawPromptExecutionOverridesRequest
343
340
  from .reducto_chunker_config import ReductoChunkerConfig
344
341
  from .reducto_chunker_config_request import ReductoChunkerConfigRequest
@@ -552,7 +549,6 @@ from .vellum_variable import VellumVariable
552
549
  from .vellum_variable_extensions import VellumVariableExtensions
553
550
  from .vellum_variable_type import VellumVariableType
554
551
  from .vellum_workflow_execution_event import VellumWorkflowExecutionEvent
555
- from .vembda_service_tier_enum import VembdaServiceTierEnum
556
552
  from .workflow_deployment_event_executions_response import WorkflowDeploymentEventExecutionsResponse
557
553
  from .workflow_deployment_history_item import WorkflowDeploymentHistoryItem
558
554
  from .workflow_deployment_parent_context import WorkflowDeploymentParentContext
@@ -848,7 +844,6 @@ __all__ = [
848
844
  "MlModelRead",
849
845
  "MlModelUsage",
850
846
  "MlModelUsageWrapper",
851
- "NameEnum",
852
847
  "NamedScenarioInputChatHistoryVariableValueRequest",
853
848
  "NamedScenarioInputJsonVariableValueRequest",
854
849
  "NamedScenarioInputRequest",
@@ -920,7 +915,6 @@ __all__ = [
920
915
  "OpenAiVectorizerTextEmbedding3SmallRequest",
921
916
  "OpenAiVectorizerTextEmbeddingAda002",
922
917
  "OpenAiVectorizerTextEmbeddingAda002Request",
923
- "OrganizationLimitConfig",
924
918
  "OrganizationRead",
925
919
  "PaginatedContainerImageReadList",
926
920
  "PaginatedDeploymentReleaseTagReadList",
@@ -961,7 +955,6 @@ __all__ = [
961
955
  "PromptRequestStringInput",
962
956
  "PromptSettings",
963
957
  "PromptVersionBuildConfigSandbox",
964
- "Quota",
965
958
  "RawPromptExecutionOverridesRequest",
966
959
  "ReductoChunkerConfig",
967
960
  "ReductoChunkerConfigRequest",
@@ -1159,7 +1152,6 @@ __all__ = [
1159
1152
  "VellumVariableExtensions",
1160
1153
  "VellumVariableType",
1161
1154
  "VellumWorkflowExecutionEvent",
1162
- "VembdaServiceTierEnum",
1163
1155
  "WorkflowDeploymentEventExecutionsResponse",
1164
1156
  "WorkflowDeploymentHistoryItem",
1165
1157
  "WorkflowDeploymentParentContext",
@@ -3,7 +3,6 @@
3
3
  from ..core.pydantic_utilities import UniversalBaseModel
4
4
  import typing
5
5
  from .new_member_join_behavior_enum import NewMemberJoinBehaviorEnum
6
- from .organization_limit_config import OrganizationLimitConfig
7
6
  from ..core.pydantic_utilities import IS_PYDANTIC_V2
8
7
  import pydantic
9
8
 
@@ -13,7 +12,7 @@ class OrganizationRead(UniversalBaseModel):
13
12
  name: str
14
13
  allow_staff_access: typing.Optional[bool] = None
15
14
  new_member_join_behavior: NewMemberJoinBehaviorEnum
16
- limit_config: OrganizationLimitConfig
15
+ limit_config: typing.Optional[typing.Dict[str, typing.Optional[typing.Any]]] = None
17
16
 
18
17
  if IS_PYDANTIC_V2:
19
18
  model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
@@ -0,0 +1,111 @@
1
+ """Monitoring execution context for workflow tracing."""
2
+
3
+ import threading
4
+ from uuid import UUID
5
+ from typing import Dict, Optional
6
+
7
+ from vellum.workflows.context import ExecutionContext
8
+
9
DEFAULT_TRACE_ID = UUID("00000000-0000-0000-0000-000000000000")

# Thread-local storage for monitoring execution context
_monitoring_execution_context: threading.local = threading.local()
# Thread-local storage for current span_id
_current_span_id: threading.local = threading.local()


class _MonitoringContextStore:
    """
    Thread-safe storage for monitoring contexts.

    Contexts are stored under a key built from the currently active trace_id, the
    context's parent span_id and the storing thread's id, so they can be looked up
    again with ``retrieve_context``. A single lock guards the shared state. This
    class never removes stored entries.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._contexts: Dict[str, "ExecutionContext"] = {}
        # NOTE(review): not read or written anywhere in this class.
        self._thread_contexts: Dict[int, "ExecutionContext"] = {}
        self._current_trace_id: Optional[UUID] = None

    @staticmethod
    def _make_key(trace_id: Optional[UUID], span_id: UUID, thread_id: int) -> str:
        """Build the storage key shared by ``store_context`` and ``retrieve_context``."""
        return f"trace:{trace_id!s}:span:{span_id!s}:thread:{thread_id}"

    def set_current_trace_id(self, trace_id: UUID) -> None:
        """Set the active trace_id shared by all threads; the default trace id is ignored."""
        if trace_id != DEFAULT_TRACE_ID:
            with self._lock:
                self._current_trace_id = trace_id

    def get_current_trace_id(self) -> Optional[UUID]:
        """Return the active trace_id shared by all threads, or None if none is set."""
        with self._lock:
            return self._current_trace_id

    def set_current_span_id(self, span_id: UUID) -> None:
        """Set the current span_id for the calling thread."""
        _current_span_id.span_id = span_id

    def get_current_span_id(self) -> Optional[UUID]:
        """Return the calling thread's current span_id, or None if none was set."""
        return getattr(_current_span_id, "span_id", None)

    def store_context(self, context: Optional["ExecutionContext"]) -> None:
        """Persist *context* under (active trace_id, parent span_id, thread id).

        Contexts that are falsy or have no ``parent_context`` are ignored. If no trace
        is active yet and the context carries a non-default trace_id, that trace
        becomes the active one.
        """
        if not context or context.parent_context is None:
            return

        thread_id = threading.get_ident()
        with self._lock:
            # Adopt the trace and build the key under the same lock. This way the key
            # uses the up-to-date trace_id rather than a stale ``None`` read earlier, which
            # would make the first stored context unretrievable.
            if self._current_trace_id is None and context.trace_id != DEFAULT_TRACE_ID:
                self._current_trace_id = context.trace_id
            key = self._make_key(self._current_trace_id, context.parent_context.span_id, thread_id)
            self._contexts[key] = context

    def retrieve_context(self, trace_id: UUID, span_id: Optional[UUID] = None) -> Optional["ExecutionContext"]:
        """Return the context stored by this thread for (*trace_id*, *span_id*), if any.

        When *span_id* is not given, the calling thread's current span_id is used.
        Returns None if there is no span_id or no matching entry.
        """
        if not span_id:
            span_id = self.get_current_span_id()
            if not span_id:
                return None

        key = self._make_key(trace_id, span_id, threading.get_ident())
        with self._lock:
            return self._contexts.get(key)


# Global instance for cross-boundary context persistence
_monitoring_context_store = _MonitoringContextStore()
85
+
86
+
87
def get_monitoring_execution_context() -> ExecutionContext:
    """Return the monitoring execution context for the calling thread.

    Resolution order:
    1. The thread-local context, if its trace_id is not the default one and it has
       a parent_context.
    2. A context found in the global store for the active trace_id and this thread's
       current span_id. On success it is cached in the thread-local.
    3. A new default ``ExecutionContext()``.
    """
    if hasattr(_monitoring_execution_context, "context"):
        local_context = _monitoring_execution_context.context
        if local_context.trace_id != DEFAULT_TRACE_ID and local_context.parent_context:
            return local_context

    active_trace_id = _monitoring_context_store.get_current_trace_id()
    active_span_id = getattr(_current_span_id, "span_id", None)

    # Without a real (non-default) active trace there is nothing to restore.
    if not active_trace_id or active_trace_id == DEFAULT_TRACE_ID:
        return ExecutionContext()

    restored = _monitoring_context_store.retrieve_context(active_trace_id, active_span_id)
    if not restored:
        return ExecutionContext()

    _monitoring_execution_context.context = restored
    return restored
104
+
105
+
106
def set_monitoring_execution_context(context: ExecutionContext) -> None:
    """Install *context* as the calling thread's monitoring context.

    When the context has a truthy trace_id and a parent_context, it is also handed
    to the global context store so it can be looked up outside this thread-local.
    """
    _monitoring_execution_context.context = context

    # Only contexts with both a trace and a parent are worth persisting.
    if not context.trace_id or not context.parent_context:
        return
    _monitoring_context_store.store_context(context)
File without changes
@@ -0,0 +1,138 @@
1
+ from dataclasses import dataclass
2
+ import logging
3
+ import os
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import requests
7
+
8
+ from vellum.workflows.exceptions import NodeException
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
@dataclass
class ConnectionInfo:
    """Description of one authorized Composio connection belonging to the user."""

    # Identifier of the connected account.
    connection_id: str
    # Integration (toolkit) slug the connection belongs to.
    integration_name: str
    # Creation timestamp as reported by the API (kept as a string).
    created_at: str
    # Last-update timestamp as reported by the API (kept as a string).
    updated_at: str
    # Connection status; defaults to "ACTIVE".
    status: str = "ACTIVE"
22
+
23
+
24
class ComposioService:
    """Composio API client for managing connections and executing tools.

    Requests go to the Composio v3 REST API (``base_url``) via ``requests``. They
    authenticate with the ``x-api-key`` header. API failures surface as ``NodeException``.
    """

    # Environment variables checked, in order, when no API key is passed explicitly.
    _API_KEY_ENV_VARS = ("COMPOSIO_API_KEY", "COMPOSIO_KEY")

    def __init__(self, api_key: Optional[str] = None):
        """Create a client.

        Args:
            api_key: Composio API key. When None, the first non-empty value among the
                ``COMPOSIO_API_KEY`` / ``COMPOSIO_KEY`` environment variables is used.

        Raises:
            NodeException: If no non-empty API key could be found.
        """
        if api_key is None:
            api_key = self._get_api_key_from_env()

        if not api_key:
            raise NodeException(
                "No Composio API key found. "
                "Please provide an api_key parameter or set one of these environment variables: "
                + ", ".join(self._API_KEY_ENV_VARS)
            )

        self.api_key = api_key
        self.base_url = "https://backend.composio.dev/api/v3"

    @staticmethod
    def _get_api_key_from_env() -> Optional[str]:
        """Return the first non-empty Composio API key found in the environment, else None."""
        for env_var_name in ComposioService._API_KEY_ENV_VARS:
            value = os.environ.get(env_var_name)
            if value:
                return value
        return None

    @staticmethod
    def _http_status_code(error: BaseException) -> Optional[int]:
        """Return the HTTP status of the ``requests.HTTPError`` chained to *error*, if any."""
        cause = error.__cause__
        # Compare with None explicitly: ``bool(Response)`` is False for 4xx/5xx responses.
        if isinstance(cause, requests.HTTPError) and cause.response is not None:
            return cause.response.status_code
        return None

    def _make_request(
        self, endpoint: str, method: str = "GET", params: Optional[dict] = None, json_data: Optional[dict] = None
    ) -> dict:
        """Send a request to ``base_url + endpoint`` and return the decoded JSON body.

        Args:
            endpoint: Path appended to ``base_url`` (e.g. ``"/connected_accounts"``).
            method: ``"GET"`` (``params`` become the query string) or ``"POST"``
                (``json_data`` becomes the JSON body).
            params: Query parameters for GET requests.
            json_data: JSON body for POST requests.

        Raises:
            NodeException: On any failure (unsupported method, transport error, non-2xx
                status, undecodable JSON). The original exception is chained as ``__cause__``.
        """
        headers = {
            "x-api-key": self.api_key,
            "Content-Type": "application/json",
        }
        url = f"{self.base_url}{endpoint}"

        try:
            if method == "GET":
                response = requests.get(url, headers=headers, params=params or {}, timeout=30)
            elif method == "POST":
                response = requests.post(url, headers=headers, json=json_data or {}, timeout=30)
            else:
                raise ValueError(f"Unsupported HTTP method: {method}")

            response.raise_for_status()
            return response.json()
        except Exception as e:
            # Chain the original error so callers can inspect details such as the HTTP status.
            raise NodeException(f"Composio API request failed: {e}") from e

    def get_user_connections(self) -> List[ConnectionInfo]:
        """Return all authorized connections for the user.

        A missing ``toolkit.slug``, ``created_at`` or ``updated_at`` becomes ``""``.
        A missing ``status`` becomes ``"ACTIVE"``. A missing or null ``toolkit`` or
        ``items`` is treated as empty.
        """
        response = self._make_request("/connected_accounts")

        return [
            ConnectionInfo(
                connection_id=item.get("id"),
                integration_name=(item.get("toolkit") or {}).get("slug", ""),
                status=item.get("status", "ACTIVE"),
                created_at=item.get("created_at", ""),
                updated_at=item.get("updated_at", ""),
            )
            for item in (response.get("items") or [])
        ]

    def get_tool_by_slug(self, tool_slug: str) -> Dict[str, Any]:
        """Get detailed information about a tool using its slug identifier

        Args:
            tool_slug: The unique slug identifier of the tool

        Returns:
            Dictionary containing detailed tool information including:
            - slug, name, description
            - toolkit info (slug, name, logo)
            - input_parameters, output_parameters
            - no_auth, available_versions, version
            - scopes, tags, deprecated info

        Raises:
            NodeException: If tool not found (404), unauthorized (401), or other API errors
        """
        endpoint = f"/tools/{tool_slug}"

        try:
            response = self._make_request(endpoint, method="GET")
        except Exception as e:
            error_message = str(e)
            status_code = self._http_status_code(e)
            if status_code is None:
                # No structured HTTP status is available (e.g. a non-HTTP error).
                # Fall back to the legacy substring heuristic on the message.
                if "404" in error_message:
                    status_code = 404
                elif "401" in error_message:
                    status_code = 401

            if status_code == 404:
                raise NodeException(f"Tool with slug '{tool_slug}' not found in Composio") from e
            if status_code == 401:
                raise NodeException(f"Unauthorized access to tool '{tool_slug}'. Check your Composio API key.") from e
            raise NodeException(f"Failed to retrieve tool details for '{tool_slug}': {error_message}") from e

        logger.info("Retrieved tool details for slug '%s': %s", tool_slug, response)
        return response

    def execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Execute a tool using direct API request

        Args:
            tool_name: The name of the tool to execute (e.g., "HACKERNEWS_GET_USER")
            arguments: Dictionary of arguments to pass to the tool

        Returns:
            The ``data`` field of the API response, or the whole response if it has none.

        Raises:
            NodeException: If the API request fails.
        """
        endpoint = f"/tools/execute/{tool_name}"
        response = self._make_request(endpoint, method="POST", json_data={"arguments": arguments})
        return response.get("data", response)
@@ -16,6 +16,7 @@ def test_run_workflow__secrets(vellum_client):
16
16
  json_={"data": [1, 2, 3]},
17
17
  headers={"X-Response-Header": "bar"},
18
18
  )
19
+ vellum_client._client_wrapper.get_headers.return_value = {"User-Agent": "vellum-ai/1.0.6"}
19
20
 
20
21
  class SimpleBaseAPINode(APINode):
21
22
  method = APIRequestMethod.POST
@@ -35,7 +36,9 @@ def test_run_workflow__secrets(vellum_client):
35
36
  assert vellum_client.execute_api.call_count == 1
36
37
  assert vellum_client.execute_api.call_args.kwargs["url"] == "https://example.vellum.ai"
37
38
  assert vellum_client.execute_api.call_args.kwargs["body"] == {"key": "value"}
38
- assert vellum_client.execute_api.call_args.kwargs["headers"] == {"X-Test-Header": "foo"}
39
+ headers = vellum_client.execute_api.call_args.kwargs["headers"]
40
+ assert headers["X-Test-Header"] == "foo"
41
+ assert "vellum-ai" in headers.get("User-Agent", "")
39
42
  bearer_token = vellum_client.execute_api.call_args.kwargs["bearer_token"]
40
43
  assert bearer_token == ClientVellumSecret(name="secret")
41
44
  assert terminal.headers == {"X-Response-Header": "bar"}
@@ -44,6 +47,7 @@ def test_run_workflow__secrets(vellum_client):
44
47
  def test_api_node_raises_error_when_api_call_fails(vellum_client):
45
48
  # GIVEN an API call that fails
46
49
  vellum_client.execute_api.side_effect = ApiError(status_code=400, body="API Error")
50
+ vellum_client._client_wrapper.get_headers.return_value = {"User-Agent": "vellum-ai/1.0.6"}
47
51
 
48
52
  class SimpleAPINode(APINode):
49
53
  method = APIRequestMethod.GET
@@ -70,7 +74,9 @@ def test_api_node_raises_error_when_api_call_fails(vellum_client):
70
74
  assert vellum_client.execute_api.call_count == 1
71
75
  assert vellum_client.execute_api.call_args.kwargs["url"] == "https://example.vellum.ai"
72
76
  assert vellum_client.execute_api.call_args.kwargs["body"] == {"key": "value"}
73
- assert vellum_client.execute_api.call_args.kwargs["headers"] == {"X-Test-Header": "foo"}
77
+ headers = vellum_client.execute_api.call_args.kwargs["headers"]
78
+ assert headers["X-Test-Header"] == "foo"
79
+ assert "vellum-ai" in headers.get("User-Agent", "")
74
80
 
75
81
 
76
82
  def test_api_node_defaults_to_get_method(vellum_client):
@@ -5,6 +5,7 @@ from requests.exceptions import JSONDecodeError
5
5
 
6
6
  from vellum.client import ApiError
7
7
  from vellum.client.core.request_options import RequestOptions
8
+ from vellum.client.types.method_enum import MethodEnum
8
9
  from vellum.client.types.vellum_secret import VellumSecret as ClientVellumSecret
9
10
  from vellum.workflows.constants import APIRequestMethod
10
11
  from vellum.workflows.errors.types import WorkflowErrorCode
@@ -20,7 +21,7 @@ class BaseAPINode(BaseNode, Generic[StateType]):
20
21
  Used to execute an API call.
21
22
 
22
23
  url: str - The URL to send the request to.
23
- method: APIRequestMethod - The HTTP method to use for the request.
24
+ method: Union[APIRequestMethod, str] - The HTTP method to use for the request.
24
25
  data: Optional[str] - The data to send in the request body.
25
26
  json: Optional["JsonObject"] - The JSON data to send in the request body.
26
27
  headers: Optional[Dict[str, Union[str, VellumSecret]]] - The headers to send in the request.
@@ -31,7 +32,7 @@ class BaseAPINode(BaseNode, Generic[StateType]):
31
32
  merge_behavior = MergeBehavior.AWAIT_ANY
32
33
 
33
34
  url: str = ""
34
- method: Optional[APIRequestMethod] = APIRequestMethod.GET
35
+ method: Optional[Union[APIRequestMethod, MethodEnum]] = APIRequestMethod.GET
35
36
  data: Optional[str] = None
36
37
  json: Optional[Json] = None
37
38
  headers: Optional[Dict[str, Union[str, VellumSecret]]] = None
@@ -43,6 +44,21 @@ class BaseAPINode(BaseNode, Generic[StateType]):
43
44
  status_code: int
44
45
  text: str
45
46
 
47
def _normalize_http_method(self, method: Union[APIRequestMethod, MethodEnum]) -> str:
    """Convert *method* into the HTTP method string used for the request.

    An ``APIRequestMethod`` member yields its ``value``. A string is upper-cased and
    must match one of the ``APIRequestMethod`` values.

    Raises:
        NodeException: With ``WorkflowErrorCode.INVALID_INPUTS`` if *method* is a string
            that is not a known method, or is neither an ``APIRequestMethod`` nor a string.
    """
    if isinstance(method, APIRequestMethod):
        return method.value

    if not isinstance(method, str):
        raise NodeException(
            f"Method must be either APIRequestMethod enum or string, got {type(method)}",
            code=WorkflowErrorCode.INVALID_INPUTS,
        )

    upper_method = method.upper()
    if upper_method not in {member.value for member in APIRequestMethod}:
        raise NodeException(f"Invalid HTTP method '{method}'", code=WorkflowErrorCode.INVALID_INPUTS)
    return upper_method
61
+
46
62
  def _validate(self) -> None:
47
63
  if not self.url or not isinstance(self.url, str) or not self.url.strip():
48
64
  raise NodeException("URL is required and must be a non-empty string", code=WorkflowErrorCode.INVALID_INPUTS)
@@ -55,7 +71,7 @@ class BaseAPINode(BaseNode, Generic[StateType]):
55
71
  def _run(
56
72
  self,
57
73
  url: str,
58
- method: Optional[APIRequestMethod] = APIRequestMethod.GET,
74
+ method: Optional[Union[APIRequestMethod, MethodEnum]] = APIRequestMethod.GET,
59
75
  data: Optional[Union[str, Any]] = None,
60
76
  json: Any = None,
61
77
  headers: Any = None,
@@ -64,23 +80,29 @@ class BaseAPINode(BaseNode, Generic[StateType]):
64
80
  ) -> Outputs:
65
81
  self._validate()
66
82
 
83
+ normalized_method = self._normalize_http_method(method) if method is not None else APIRequestMethod.GET.value
84
+
67
85
  vellum_instance = False
68
86
  for header in headers or {}:
69
87
  if isinstance(headers[header], VellumSecret):
70
88
  vellum_instance = True
71
89
  if vellum_instance or bearer_token:
72
- return self._vellum_execute_api(bearer_token, json, headers, method, url, timeout)
90
+ return self._vellum_execute_api(bearer_token, json, headers, normalized_method, url, timeout)
73
91
  else:
74
- return self._local_execute_api(data, headers, json, method, url, timeout)
92
+ return self._local_execute_api(data, headers, json, normalized_method, url, timeout)
75
93
 
76
94
  def _local_execute_api(self, data, headers, json, method, url, timeout):
95
+ headers = headers or {}
96
+ if "User-Agent" not in headers:
97
+ client_headers = self._context.vellum_client._client_wrapper.get_headers()
98
+ headers["User-Agent"] = client_headers.get("User-Agent")
77
99
  try:
78
100
  if data is not None:
79
- prepped = Request(method=method.value, url=url, data=data, headers=headers).prepare()
101
+ prepped = Request(method=method, url=url, data=data, headers=headers).prepare()
80
102
  elif json is not None:
81
- prepped = Request(method=method.value, url=url, json=json, headers=headers).prepare()
103
+ prepped = Request(method=method, url=url, json=json, headers=headers).prepare()
82
104
  else:
83
- prepped = Request(method=method.value, url=url, headers=headers).prepare()
105
+ prepped = Request(method=method, url=url, headers=headers).prepare()
84
106
  except Exception as e:
85
107
  raise NodeException(f"Failed to prepare HTTP request: {e}", code=WorkflowErrorCode.PROVIDER_ERROR)
86
108
  try:
@@ -102,6 +124,11 @@ class BaseAPINode(BaseNode, Generic[StateType]):
102
124
  def _vellum_execute_api(self, bearer_token, data, headers, method, url, timeout):
103
125
  client_vellum_secret = ClientVellumSecret(name=bearer_token.name) if bearer_token else None
104
126
 
127
+ headers = headers or {}
128
+ if "User-Agent" not in headers:
129
+ client_headers = self._context.vellum_client._client_wrapper.get_headers()
130
+ headers["User-Agent"] = client_headers.get("User-Agent")
131
+
105
132
  # Create request_options if timeout is specified
106
133
  request_options = None
107
134
  if timeout is not None:
@@ -110,7 +137,7 @@ class BaseAPINode(BaseNode, Generic[StateType]):
110
137
  try:
111
138
  vellum_response = self._context.vellum_client.execute_api(
112
139
  url=url,
113
- method=method.value,
140
+ method=method,
114
141
  body=data,
115
142
  headers=headers,
116
143
  bearer_token=client_vellum_secret,