vellum-ai 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/__init__.py +2 -0
- vellum/client/core/client_wrapper.py +2 -2
- vellum/client/types/__init__.py +2 -0
- vellum/client/types/name_enum.py +7 -0
- vellum/client/types/organization_limit_config.py +1 -0
- vellum/client/types/quota.py +2 -1
- vellum/prompts/blocks/compilation.py +5 -1
- vellum/prompts/blocks/tests/test_compilation.py +64 -0
- vellum/types/name_enum.py +3 -0
- vellum/workflows/descriptors/base.py +12 -0
- vellum/workflows/expressions/concat.py +32 -0
- vellum/workflows/expressions/tests/test_concat.py +53 -0
- vellum/workflows/nodes/displayable/inline_prompt_node/node.py +1 -2
- vellum/workflows/nodes/displayable/prompt_deployment_node/node.py +1 -2
- vellum/workflows/nodes/displayable/tool_calling_node/composio_service.py +83 -0
- vellum/workflows/nodes/displayable/tool_calling_node/tests/test_composio_service.py +122 -0
- vellum/workflows/nodes/displayable/tool_calling_node/tests/test_utils.py +21 -1
- vellum/workflows/nodes/displayable/tool_calling_node/utils.py +133 -57
- vellum/workflows/types/core.py +2 -2
- vellum/workflows/types/definition.py +20 -1
- vellum/workflows/types/tests/test_definition.py +14 -1
- vellum/workflows/utils/functions.py +13 -1
- vellum/workflows/utils/tests/test_functions.py +32 -1
- {vellum_ai-1.0.3.dist-info → vellum_ai-1.0.5.dist-info}/METADATA +3 -1
- {vellum_ai-1.0.3.dist-info → vellum_ai-1.0.5.dist-info}/RECORD +29 -22
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_composio_serialization.py +86 -0
- {vellum_ai-1.0.3.dist-info → vellum_ai-1.0.5.dist-info}/LICENSE +0 -0
- {vellum_ai-1.0.3.dist-info → vellum_ai-1.0.5.dist-info}/WHEEL +0 -0
- {vellum_ai-1.0.3.dist-info → vellum_ai-1.0.5.dist-info}/entry_points.txt +0 -0
vellum/__init__.py
CHANGED
@@ -217,6 +217,7 @@ from .client.types import (
|
|
217
217
|
MlModelRead,
|
218
218
|
MlModelUsage,
|
219
219
|
MlModelUsageWrapper,
|
220
|
+
NameEnum,
|
220
221
|
NamedScenarioInputChatHistoryVariableValueRequest,
|
221
222
|
NamedScenarioInputJsonVariableValueRequest,
|
222
223
|
NamedScenarioInputRequest,
|
@@ -867,6 +868,7 @@ __all__ = [
|
|
867
868
|
"MlModelRead",
|
868
869
|
"MlModelUsage",
|
869
870
|
"MlModelUsageWrapper",
|
871
|
+
"NameEnum",
|
870
872
|
"NamedScenarioInputChatHistoryVariableValueRequest",
|
871
873
|
"NamedScenarioInputJsonVariableValueRequest",
|
872
874
|
"NamedScenarioInputRequest",
|
@@ -25,10 +25,10 @@ class BaseClientWrapper:
|
|
25
25
|
|
26
26
|
def get_headers(self) -> typing.Dict[str, str]:
|
27
27
|
headers: typing.Dict[str, str] = {
|
28
|
-
"User-Agent": "vellum-ai/1.0.3",
|
28
|
+
"User-Agent": "vellum-ai/1.0.5",
|
29
29
|
"X-Fern-Language": "Python",
|
30
30
|
"X-Fern-SDK-Name": "vellum-ai",
|
31
|
-
"X-Fern-SDK-Version": "1.0.3",
|
31
|
+
"X-Fern-SDK-Version": "1.0.5",
|
32
32
|
}
|
33
33
|
if self._api_version is not None:
|
34
34
|
headers["X-API-Version"] = self._api_version
|
vellum/client/types/__init__.py
CHANGED
@@ -225,6 +225,7 @@ from .metric_node_result import MetricNodeResult
|
|
225
225
|
from .ml_model_read import MlModelRead
|
226
226
|
from .ml_model_usage import MlModelUsage
|
227
227
|
from .ml_model_usage_wrapper import MlModelUsageWrapper
|
228
|
+
from .name_enum import NameEnum
|
228
229
|
from .named_scenario_input_chat_history_variable_value_request import NamedScenarioInputChatHistoryVariableValueRequest
|
229
230
|
from .named_scenario_input_json_variable_value_request import NamedScenarioInputJsonVariableValueRequest
|
230
231
|
from .named_scenario_input_request import NamedScenarioInputRequest
|
@@ -847,6 +848,7 @@ __all__ = [
|
|
847
848
|
"MlModelRead",
|
848
849
|
"MlModelUsage",
|
849
850
|
"MlModelUsageWrapper",
|
851
|
+
"NameEnum",
|
850
852
|
"NamedScenarioInputChatHistoryVariableValueRequest",
|
851
853
|
"NamedScenarioInputJsonVariableValueRequest",
|
852
854
|
"NamedScenarioInputRequest",
|
@@ -13,6 +13,7 @@ class OrganizationLimitConfig(UniversalBaseModel):
|
|
13
13
|
prompt_executions_quota: typing.Optional[Quota] = None
|
14
14
|
workflow_executions_quota: typing.Optional[Quota] = None
|
15
15
|
workflow_runtime_seconds_quota: typing.Optional[Quota] = None
|
16
|
+
max_workflow_runtime_seconds: typing.Optional[int] = None
|
16
17
|
|
17
18
|
if IS_PYDANTIC_V2:
|
18
19
|
model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
|
vellum/client/types/quota.py
CHANGED
@@ -1,13 +1,14 @@
|
|
1
1
|
# This file was auto-generated by Fern from our API Definition.
|
2
2
|
|
3
3
|
from ..core.pydantic_utilities import UniversalBaseModel
|
4
|
+
from .name_enum import NameEnum
|
4
5
|
import typing
|
5
6
|
from ..core.pydantic_utilities import IS_PYDANTIC_V2
|
6
7
|
import pydantic
|
7
8
|
|
8
9
|
|
9
10
|
class Quota(UniversalBaseModel):
|
10
|
-
name:
|
11
|
+
name: NameEnum
|
11
12
|
value: typing.Optional[int] = None
|
12
13
|
period_seconds: typing.Optional[int] = None
|
13
14
|
|
@@ -105,7 +105,11 @@ def compile_prompt_blocks(
|
|
105
105
|
cache_config=block.cache_config,
|
106
106
|
)
|
107
107
|
)
|
108
|
-
elif compiled_input == "JSON":
|
108
|
+
elif compiled_input.type == "JSON":
|
109
|
+
# Skip empty JSON arrays when there are chat message blocks present
|
110
|
+
if compiled_input.value == [] and any(block.block_type == "CHAT_MESSAGE" for block in compiled_blocks):
|
111
|
+
continue
|
112
|
+
|
109
113
|
compiled_blocks.append(
|
110
114
|
CompiledValuePromptBlock(
|
111
115
|
content=JsonVellumValue(value=compiled_input.value),
|
@@ -10,7 +10,10 @@ from vellum import (
|
|
10
10
|
VariablePromptBlock,
|
11
11
|
VellumVariable,
|
12
12
|
)
|
13
|
+
from vellum.client.types.json_vellum_value import JsonVellumValue
|
13
14
|
from vellum.client.types.number_input import NumberInput
|
15
|
+
from vellum.client.types.prompt_block import PromptBlock
|
16
|
+
from vellum.client.types.prompt_request_json_input import PromptRequestJsonInput
|
14
17
|
from vellum.prompts.blocks.compilation import compile_prompt_blocks
|
15
18
|
from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock, CompiledValuePromptBlock
|
16
19
|
|
@@ -146,3 +149,64 @@ def test_compile_prompt_blocks__happy(blocks, inputs, input_variables, expected)
|
|
146
149
|
actual = compile_prompt_blocks(blocks=blocks, inputs=inputs, input_variables=input_variables)
|
147
150
|
|
148
151
|
assert actual == expected
|
152
|
+
|
153
|
+
|
154
|
+
def test_compile_prompt_blocks__empty_json_variable_with_chat_message_blocks():
|
155
|
+
"""Test JSON variable handling logic, specifically the empty array skipping behavior."""
|
156
|
+
|
157
|
+
# GIVEN empty array with chat message blocks
|
158
|
+
blocks_with_chat: list[PromptBlock] = [
|
159
|
+
ChatMessagePromptBlock(
|
160
|
+
chat_role="USER",
|
161
|
+
blocks=[RichTextPromptBlock(blocks=[PlainTextPromptBlock(text="User message")])],
|
162
|
+
),
|
163
|
+
VariablePromptBlock(input_variable="json_data"),
|
164
|
+
]
|
165
|
+
|
166
|
+
inputs_with_empty_json = [PromptRequestJsonInput(key="json_data", value=[], type="JSON")]
|
167
|
+
|
168
|
+
input_variables = [VellumVariable(id="901ec2d6-430c-4341-b963-ca689006f5cc", type="JSON", key="json_data")]
|
169
|
+
|
170
|
+
# THEN the empty JSON array should be skipped when there are chat message blocks
|
171
|
+
expected_with_chat = [
|
172
|
+
CompiledChatMessagePromptBlock(
|
173
|
+
role="USER",
|
174
|
+
blocks=[CompiledValuePromptBlock(content=StringVellumValue(value="User message"))],
|
175
|
+
),
|
176
|
+
]
|
177
|
+
|
178
|
+
actual = compile_prompt_blocks(
|
179
|
+
blocks=blocks_with_chat, inputs=inputs_with_empty_json, input_variables=input_variables
|
180
|
+
)
|
181
|
+
assert actual == expected_with_chat
|
182
|
+
|
183
|
+
|
184
|
+
def test_compile_prompt_blocks__non_empty_json_variable_with_chat_message_blocks():
|
185
|
+
"""Test that non-empty JSON variables are included even when there are chat message blocks."""
|
186
|
+
|
187
|
+
# GIVEN non-empty JSON with chat message blocks
|
188
|
+
blocks_with_chat: list[PromptBlock] = [
|
189
|
+
ChatMessagePromptBlock(
|
190
|
+
chat_role="USER",
|
191
|
+
blocks=[RichTextPromptBlock(blocks=[PlainTextPromptBlock(text="User message")])],
|
192
|
+
),
|
193
|
+
VariablePromptBlock(input_variable="json_data"),
|
194
|
+
]
|
195
|
+
|
196
|
+
inputs_with_non_empty_json = [PromptRequestJsonInput(key="json_data", value={"key": "value"}, type="JSON")]
|
197
|
+
|
198
|
+
input_variables = [VellumVariable(id="901ec2d6-430c-4341-b963-ca689006f5cc", type="JSON", key="json_data")]
|
199
|
+
|
200
|
+
# THEN the non-empty JSON should be included
|
201
|
+
expected_with_non_empty = [
|
202
|
+
CompiledChatMessagePromptBlock(
|
203
|
+
role="USER",
|
204
|
+
blocks=[CompiledValuePromptBlock(content=StringVellumValue(value="User message"))],
|
205
|
+
),
|
206
|
+
CompiledValuePromptBlock(content=JsonVellumValue(value={"key": "value"})),
|
207
|
+
]
|
208
|
+
|
209
|
+
actual = compile_prompt_blocks(
|
210
|
+
blocks=blocks_with_chat, inputs=inputs_with_non_empty_json, input_variables=input_variables
|
211
|
+
)
|
212
|
+
assert actual == expected_with_non_empty
|
@@ -6,6 +6,7 @@ if TYPE_CHECKING:
|
|
6
6
|
from vellum.workflows.expressions.begins_with import BeginsWithExpression
|
7
7
|
from vellum.workflows.expressions.between import BetweenExpression
|
8
8
|
from vellum.workflows.expressions.coalesce_expression import CoalesceExpression
|
9
|
+
from vellum.workflows.expressions.concat import ConcatExpression
|
9
10
|
from vellum.workflows.expressions.contains import ContainsExpression
|
10
11
|
from vellum.workflows.expressions.does_not_begin_with import DoesNotBeginWithExpression
|
11
12
|
from vellum.workflows.expressions.does_not_contain import DoesNotContainExpression
|
@@ -364,3 +365,14 @@ class BaseDescriptor(Generic[_T]):
|
|
364
365
|
from vellum.workflows.expressions.is_error import IsErrorExpression
|
365
366
|
|
366
367
|
return IsErrorExpression(expression=self)
|
368
|
+
|
369
|
+
@overload
|
370
|
+
def concat(self, other: "BaseDescriptor[_O]") -> "ConcatExpression[_T, _O]": ...
|
371
|
+
|
372
|
+
@overload
|
373
|
+
def concat(self, other: _O) -> "ConcatExpression[_T, _O]": ...
|
374
|
+
|
375
|
+
def concat(self, other: "Union[BaseDescriptor[_O], _O]") -> "ConcatExpression[_T, _O]":
|
376
|
+
from vellum.workflows.expressions.concat import ConcatExpression
|
377
|
+
|
378
|
+
return ConcatExpression(lhs=self, rhs=other)
|
@@ -0,0 +1,32 @@
|
|
1
|
+
from typing import Generic, TypeVar, Union
|
2
|
+
|
3
|
+
from vellum.workflows.descriptors.base import BaseDescriptor
|
4
|
+
from vellum.workflows.descriptors.exceptions import InvalidExpressionException
|
5
|
+
from vellum.workflows.descriptors.utils import resolve_value
|
6
|
+
from vellum.workflows.state.base import BaseState
|
7
|
+
|
8
|
+
LHS = TypeVar("LHS")
|
9
|
+
RHS = TypeVar("RHS")
|
10
|
+
|
11
|
+
|
12
|
+
class ConcatExpression(BaseDescriptor[list], Generic[LHS, RHS]):
|
13
|
+
def __init__(
|
14
|
+
self,
|
15
|
+
*,
|
16
|
+
lhs: Union[BaseDescriptor[LHS], LHS],
|
17
|
+
rhs: Union[BaseDescriptor[RHS], RHS],
|
18
|
+
) -> None:
|
19
|
+
super().__init__(name=f"{lhs} + {rhs}", types=(list,))
|
20
|
+
self._lhs = lhs
|
21
|
+
self._rhs = rhs
|
22
|
+
|
23
|
+
def resolve(self, state: "BaseState") -> list:
|
24
|
+
lval = resolve_value(self._lhs, state)
|
25
|
+
rval = resolve_value(self._rhs, state)
|
26
|
+
|
27
|
+
if not isinstance(lval, list):
|
28
|
+
raise InvalidExpressionException(f"Expected LHS to be a list, got {type(lval)}")
|
29
|
+
if not isinstance(rval, list):
|
30
|
+
raise InvalidExpressionException(f"Expected RHS to be a list, got {type(rval)}")
|
31
|
+
|
32
|
+
return lval + rval
|
@@ -0,0 +1,53 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
from vellum.workflows.descriptors.exceptions import InvalidExpressionException
|
4
|
+
from vellum.workflows.references.constant import ConstantValueReference
|
5
|
+
from vellum.workflows.state.base import BaseState
|
6
|
+
|
7
|
+
|
8
|
+
class TestState(BaseState):
|
9
|
+
pass
|
10
|
+
|
11
|
+
|
12
|
+
def test_concat_expression_happy_path():
|
13
|
+
# GIVEN two lists
|
14
|
+
state = TestState()
|
15
|
+
lhs_ref = ConstantValueReference([1, 2, 3])
|
16
|
+
rhs_ref = ConstantValueReference([4, 5, 6])
|
17
|
+
concat_expr = lhs_ref.concat(rhs_ref)
|
18
|
+
|
19
|
+
# WHEN we resolve the expression
|
20
|
+
result = concat_expr.resolve(state)
|
21
|
+
|
22
|
+
# THEN the lists should be concatenated
|
23
|
+
assert result == [1, 2, 3, 4, 5, 6]
|
24
|
+
|
25
|
+
|
26
|
+
def test_concat_expression_lhs_fail():
|
27
|
+
# GIVEN a non-list lhs and a list rhs
|
28
|
+
state = TestState()
|
29
|
+
lhs_ref = ConstantValueReference(0)
|
30
|
+
rhs_ref = ConstantValueReference([4, 5, 6])
|
31
|
+
concat_expr = lhs_ref.concat(rhs_ref)
|
32
|
+
|
33
|
+
# WHEN we attempt to resolve the expression
|
34
|
+
with pytest.raises(InvalidExpressionException) as exc_info:
|
35
|
+
concat_expr.resolve(state)
|
36
|
+
|
37
|
+
# THEN an exception should be raised
|
38
|
+
assert "Expected LHS to be a list, got <class 'int'>" in str(exc_info.value)
|
39
|
+
|
40
|
+
|
41
|
+
def test_concat_expression_rhs_fail():
|
42
|
+
# GIVEN a list lhs and a non-list rhs
|
43
|
+
state = TestState()
|
44
|
+
lhs_ref = ConstantValueReference([1, 2, 3])
|
45
|
+
rhs_ref = ConstantValueReference(False)
|
46
|
+
concat_expr = lhs_ref.concat(rhs_ref)
|
47
|
+
|
48
|
+
# WHEN we attempt to resolve the expression
|
49
|
+
with pytest.raises(InvalidExpressionException) as exc_info:
|
50
|
+
concat_expr.resolve(state)
|
51
|
+
|
52
|
+
# THEN an exception should be raised
|
53
|
+
assert "Expected RHS to be a list, got <class 'bool'>" in str(exc_info.value)
|
@@ -64,8 +64,7 @@ class InlinePromptNode(BaseInlinePromptNode[StateType]):
|
|
64
64
|
elif output.type == "FUNCTION_CALL":
|
65
65
|
string_outputs.append(output.value.model_dump_json(indent=4))
|
66
66
|
elif output.type == "THINKING":
|
67
|
-
|
68
|
-
string_outputs.append(output.value.value)
|
67
|
+
continue
|
69
68
|
else:
|
70
69
|
string_outputs.append(output.value.message)
|
71
70
|
|
@@ -66,8 +66,7 @@ class PromptDeploymentNode(BasePromptDeploymentNode[StateType]):
|
|
66
66
|
elif output.type == "FUNCTION_CALL":
|
67
67
|
string_outputs.append(output.value.model_dump_json(indent=4))
|
68
68
|
elif output.type == "THINKING":
|
69
|
-
|
70
|
-
string_outputs.append(output.value.value)
|
69
|
+
continue
|
71
70
|
else:
|
72
71
|
string_outputs.append(output.value.message)
|
73
72
|
|
@@ -0,0 +1,83 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from typing import Any, Dict, List
|
3
|
+
|
4
|
+
from composio import Action, Composio
|
5
|
+
from composio_client import Composio as ComposioClient
|
6
|
+
|
7
|
+
|
8
|
+
@dataclass
|
9
|
+
class ConnectionInfo:
|
10
|
+
"""Information about a user's authorized connection"""
|
11
|
+
|
12
|
+
connection_id: str
|
13
|
+
integration_name: str
|
14
|
+
created_at: str
|
15
|
+
updated_at: str
|
16
|
+
status: str = "ACTIVE" # TODO: Use enum if we end up supporting integrations that the user has not yet connected to
|
17
|
+
|
18
|
+
|
19
|
+
class ComposioAccountService:
|
20
|
+
"""Manages user authorized connections using composio-client"""
|
21
|
+
|
22
|
+
def __init__(self, api_key: str):
|
23
|
+
self.client = ComposioClient(api_key=api_key)
|
24
|
+
|
25
|
+
def get_user_connections(self) -> List[ConnectionInfo]:
|
26
|
+
"""Get all authorized connections for the user"""
|
27
|
+
response = self.client.connected_accounts.list()
|
28
|
+
|
29
|
+
return [
|
30
|
+
ConnectionInfo(
|
31
|
+
connection_id=item.id,
|
32
|
+
integration_name=item.toolkit.slug,
|
33
|
+
status=item.status,
|
34
|
+
created_at=item.created_at,
|
35
|
+
updated_at=item.updated_at,
|
36
|
+
)
|
37
|
+
for item in response.items
|
38
|
+
]
|
39
|
+
|
40
|
+
|
41
|
+
class ComposioCoreService:
|
42
|
+
"""Handles tool execution using composio-core"""
|
43
|
+
|
44
|
+
def __init__(self, api_key: str):
|
45
|
+
self.client = Composio(api_key=api_key)
|
46
|
+
|
47
|
+
def execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
|
48
|
+
"""Execute a tool using composio-core
|
49
|
+
|
50
|
+
Args:
|
51
|
+
tool_name: The name of the tool to execute (e.g., "HACKERNEWS_GET_USER")
|
52
|
+
arguments: Dictionary of arguments to pass to the tool
|
53
|
+
|
54
|
+
Returns:
|
55
|
+
The result of the tool execution
|
56
|
+
"""
|
57
|
+
# Convert tool name string to Action enum
|
58
|
+
action = getattr(Action, tool_name)
|
59
|
+
return self.client.actions.execute(action, params=arguments)
|
60
|
+
|
61
|
+
|
62
|
+
class ComposioService:
|
63
|
+
"""Unified interface for Composio operations"""
|
64
|
+
|
65
|
+
def __init__(self, api_key: str):
|
66
|
+
self.accounts = ComposioAccountService(api_key)
|
67
|
+
self.core = ComposioCoreService(api_key)
|
68
|
+
|
69
|
+
def get_user_connections(self) -> List[ConnectionInfo]:
|
70
|
+
"""Get user's authorized connections"""
|
71
|
+
return self.accounts.get_user_connections()
|
72
|
+
|
73
|
+
def execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
|
74
|
+
"""Execute a tool using composio-core
|
75
|
+
|
76
|
+
Args:
|
77
|
+
tool_name: The name of the tool to execute (e.g., "HACKERNEWS_GET_USER")
|
78
|
+
arguments: Dictionary of arguments to pass to the tool
|
79
|
+
|
80
|
+
Returns:
|
81
|
+
The result of the tool execution
|
82
|
+
"""
|
83
|
+
return self.core.execute_tool(tool_name, arguments)
|
@@ -0,0 +1,122 @@
|
|
1
|
+
import pytest
|
2
|
+
from unittest.mock import Mock, patch
|
3
|
+
|
4
|
+
from vellum.workflows.nodes.displayable.tool_calling_node.composio_service import ComposioService, ConnectionInfo
|
5
|
+
|
6
|
+
|
7
|
+
@pytest.fixture
|
8
|
+
def mock_composio_client():
|
9
|
+
"""Mock the Composio client completely"""
|
10
|
+
with patch("vellum.workflows.nodes.displayable.tool_calling_node.composio_service.ComposioClient") as mock_composio:
|
11
|
+
yield mock_composio.return_value
|
12
|
+
|
13
|
+
|
14
|
+
@pytest.fixture
|
15
|
+
def mock_connected_accounts_response():
|
16
|
+
"""Mock response for connected accounts"""
|
17
|
+
mock_item1 = Mock()
|
18
|
+
mock_item1.id = "conn-123"
|
19
|
+
mock_item1.toolkit.slug = "github"
|
20
|
+
mock_item1.status = "ACTIVE"
|
21
|
+
mock_item1.created_at = "2023-01-01T00:00:00Z"
|
22
|
+
mock_item1.updated_at = "2023-01-15T10:30:00Z"
|
23
|
+
|
24
|
+
mock_item2 = Mock()
|
25
|
+
mock_item2.id = "conn-456"
|
26
|
+
mock_item2.toolkit.slug = "slack"
|
27
|
+
mock_item2.status = "ACTIVE"
|
28
|
+
mock_item2.created_at = "2023-01-01T00:00:00Z"
|
29
|
+
mock_item2.updated_at = "2023-01-10T08:00:00Z"
|
30
|
+
|
31
|
+
mock_response = Mock()
|
32
|
+
mock_response.items = [mock_item1, mock_item2]
|
33
|
+
|
34
|
+
return mock_response
|
35
|
+
|
36
|
+
|
37
|
+
@pytest.fixture
|
38
|
+
def mock_composio_core_client():
|
39
|
+
"""Mock the composio-core Composio client"""
|
40
|
+
with patch("vellum.workflows.nodes.displayable.tool_calling_node.composio_service.Composio") as mock_composio:
|
41
|
+
yield mock_composio.return_value
|
42
|
+
|
43
|
+
|
44
|
+
@pytest.fixture
|
45
|
+
def mock_action():
|
46
|
+
"""Mock the Action class and specific actions"""
|
47
|
+
with patch("vellum.workflows.nodes.displayable.tool_calling_node.composio_service.Action") as mock_action_class:
|
48
|
+
# Mock a specific action
|
49
|
+
mock_hackernews_action = Mock()
|
50
|
+
mock_action_class.HACKERNEWS_GET_USER = mock_hackernews_action
|
51
|
+
mock_action_class.GITHUB_GET_USER = Mock()
|
52
|
+
yield mock_action_class
|
53
|
+
|
54
|
+
|
55
|
+
@pytest.fixture
|
56
|
+
def composio_service(mock_composio_client, mock_composio_core_client):
|
57
|
+
"""Create ComposioService with mocked clients"""
|
58
|
+
return ComposioService(api_key="test-key")
|
59
|
+
|
60
|
+
|
61
|
+
class TestComposioAccountService:
|
62
|
+
"""Test suite for ComposioAccountService"""
|
63
|
+
|
64
|
+
def test_get_user_connections_success(
|
65
|
+
self, composio_service, mock_composio_client, mock_connected_accounts_response
|
66
|
+
):
|
67
|
+
"""Test successful retrieval of user connections"""
|
68
|
+
# GIVEN the Composio client returns a valid response with two connections
|
69
|
+
mock_composio_client.connected_accounts.list.return_value = mock_connected_accounts_response
|
70
|
+
|
71
|
+
# WHEN we request user connections
|
72
|
+
result = composio_service.get_user_connections()
|
73
|
+
|
74
|
+
# THEN we get two properly formatted ConnectionInfo objects
|
75
|
+
assert len(result) == 2
|
76
|
+
assert isinstance(result[0], ConnectionInfo)
|
77
|
+
assert result[0].connection_id == "conn-123"
|
78
|
+
assert result[0].integration_name == "github"
|
79
|
+
assert result[0].status == "ACTIVE"
|
80
|
+
assert result[0].created_at == "2023-01-01T00:00:00Z"
|
81
|
+
assert result[0].updated_at == "2023-01-15T10:30:00Z"
|
82
|
+
|
83
|
+
assert result[1].connection_id == "conn-456"
|
84
|
+
assert result[1].integration_name == "slack"
|
85
|
+
assert result[1].status == "ACTIVE"
|
86
|
+
assert result[1].created_at == "2023-01-01T00:00:00Z"
|
87
|
+
assert result[1].updated_at == "2023-01-10T08:00:00Z"
|
88
|
+
|
89
|
+
mock_composio_client.connected_accounts.list.assert_called_once()
|
90
|
+
|
91
|
+
def test_get_user_connections_empty_response(self, composio_service, mock_composio_client):
|
92
|
+
"""Test handling of empty connections response"""
|
93
|
+
# GIVEN the Composio client returns an empty response
|
94
|
+
mock_response = Mock()
|
95
|
+
mock_response.items = []
|
96
|
+
mock_composio_client.connected_accounts.list.return_value = mock_response
|
97
|
+
|
98
|
+
# WHEN we request user connections
|
99
|
+
result = composio_service.get_user_connections()
|
100
|
+
|
101
|
+
# THEN we get an empty list
|
102
|
+
assert result == []
|
103
|
+
|
104
|
+
|
105
|
+
class TestComposioCoreService:
|
106
|
+
"""Test suite for ComposioCoreService"""
|
107
|
+
|
108
|
+
def test_execute_tool_success(self, composio_service, mock_composio_core_client, mock_action):
|
109
|
+
"""Test executing a tool with complex argument structure"""
|
110
|
+
# GIVEN complex arguments and a mock response
|
111
|
+
complex_args = {"filters": {"status": "active"}, "limit": 10, "sort": "created_at"}
|
112
|
+
expected_result = {"items": [], "total": 0}
|
113
|
+
mock_composio_core_client.actions.execute.return_value = expected_result
|
114
|
+
|
115
|
+
# WHEN we execute a tool with complex arguments
|
116
|
+
result = composio_service.execute_tool("HACKERNEWS_GET_USER", complex_args)
|
117
|
+
|
118
|
+
# THEN the arguments are passed through correctly
|
119
|
+
mock_composio_core_client.actions.execute.assert_called_once_with(
|
120
|
+
mock_action.HACKERNEWS_GET_USER, params=complex_args
|
121
|
+
)
|
122
|
+
assert result == expected_result
|
@@ -1,10 +1,12 @@
|
|
1
|
+
import pytest
|
2
|
+
|
1
3
|
from vellum.workflows import BaseWorkflow
|
2
4
|
from vellum.workflows.inputs.base import BaseInputs
|
3
5
|
from vellum.workflows.nodes.bases import BaseNode
|
4
6
|
from vellum.workflows.nodes.displayable.tool_calling_node.utils import get_function_name
|
5
7
|
from vellum.workflows.outputs.base import BaseOutputs
|
6
8
|
from vellum.workflows.state.base import BaseState
|
7
|
-
from vellum.workflows.types.definition import DeploymentDefinition
|
9
|
+
from vellum.workflows.types.definition import ComposioToolDefinition, DeploymentDefinition
|
8
10
|
|
9
11
|
|
10
12
|
def test_get_function_name_callable():
|
@@ -56,3 +58,21 @@ def test_get_function_name_subworkflow_deployment_uuid():
|
|
56
58
|
result = get_function_name(deployment_config)
|
57
59
|
|
58
60
|
assert result == "57f09bebb46340e0bf9ec972e664352f"
|
61
|
+
|
62
|
+
|
63
|
+
@pytest.mark.parametrize(
|
64
|
+
"toolkit,action,description,expected_result",
|
65
|
+
[
|
66
|
+
("SLACK", "SLACK_SEND_MESSAGE", "Send message to Slack", "slack_send_message"),
|
67
|
+
("GMAIL", "GMAIL_CREATE_EMAIL_DRAFT", "Create Gmail draft", "gmail_create_email_draft"),
|
68
|
+
],
|
69
|
+
)
|
70
|
+
def test_get_function_name_composio_tool_definition_various_toolkits(
|
71
|
+
toolkit: str, action: str, description: str, expected_result: str
|
72
|
+
):
|
73
|
+
"""Test ComposioToolDefinition function name generation with various toolkits."""
|
74
|
+
composio_tool = ComposioToolDefinition(toolkit=toolkit, action=action, description=description)
|
75
|
+
|
76
|
+
result = get_function_name(composio_tool)
|
77
|
+
|
78
|
+
assert result == expected_result
|