vellum-ai 1.0.9__py3-none-any.whl → 1.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +2 -2
- vellum/workflows/descriptors/base.py +31 -1
- vellum/workflows/descriptors/utils.py +19 -1
- vellum/workflows/emitters/__init__.py +2 -0
- vellum/workflows/emitters/base.py +17 -0
- vellum/workflows/emitters/vellum_emitter.py +138 -0
- vellum/workflows/expressions/accessor.py +23 -15
- vellum/workflows/expressions/add.py +41 -0
- vellum/workflows/expressions/length.py +35 -0
- vellum/workflows/expressions/minus.py +41 -0
- vellum/workflows/expressions/tests/test_add.py +72 -0
- vellum/workflows/expressions/tests/test_length.py +38 -0
- vellum/workflows/expressions/tests/test_minus.py +72 -0
- vellum/workflows/integrations/composio_service.py +10 -2
- vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +1 -1
- vellum/workflows/nodes/displayable/inline_prompt_node/node.py +2 -2
- vellum/workflows/nodes/displayable/tool_calling_node/node.py +24 -20
- vellum/workflows/nodes/displayable/tool_calling_node/state.py +2 -0
- vellum/workflows/nodes/displayable/tool_calling_node/tests/test_composio_service.py +92 -0
- vellum/workflows/nodes/displayable/tool_calling_node/tests/test_node.py +25 -10
- vellum/workflows/nodes/displayable/tool_calling_node/tests/test_utils.py +7 -5
- vellum/workflows/nodes/displayable/tool_calling_node/utils.py +141 -86
- vellum/workflows/types/core.py +3 -5
- vellum/workflows/types/definition.py +2 -6
- vellum/workflows/types/tests/test_definition.py +5 -2
- {vellum_ai-1.0.9.dist-info → vellum_ai-1.0.11.dist-info}/METADATA +1 -1
- {vellum_ai-1.0.9.dist-info → vellum_ai-1.0.11.dist-info}/RECORD +34 -27
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_composio_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_inline_workflow_serialization.py +0 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_serialization.py +0 -5
- vellum_ee/workflows/display/utils/expressions.py +12 -0
- {vellum_ai-1.0.9.dist-info → vellum_ai-1.0.11.dist-info}/LICENSE +0 -0
- {vellum_ai-1.0.9.dist-info → vellum_ai-1.0.11.dist-info}/WHEEL +0 -0
- {vellum_ai-1.0.9.dist-info → vellum_ai-1.0.11.dist-info}/entry_points.txt +0 -0
@@ -25,10 +25,10 @@ class BaseClientWrapper:
|
|
25
25
|
|
26
26
|
def get_headers(self) -> typing.Dict[str, str]:
|
27
27
|
headers: typing.Dict[str, str] = {
|
28
|
-
"User-Agent": "vellum-ai/1.0.
|
28
|
+
"User-Agent": "vellum-ai/1.0.11",
|
29
29
|
"X-Fern-Language": "Python",
|
30
30
|
"X-Fern-SDK-Name": "vellum-ai",
|
31
|
-
"X-Fern-SDK-Version": "1.0.
|
31
|
+
"X-Fern-SDK-Version": "1.0.11",
|
32
32
|
}
|
33
33
|
if self._api_version is not None:
|
34
34
|
headers["X-API-Version"] = self._api_version
|
@@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, Tuple, Type, TypeVar,
|
|
2
2
|
|
3
3
|
if TYPE_CHECKING:
|
4
4
|
from vellum.workflows.expressions.accessor import AccessorExpression
|
5
|
+
from vellum.workflows.expressions.add import AddExpression
|
5
6
|
from vellum.workflows.expressions.and_ import AndExpression
|
6
7
|
from vellum.workflows.expressions.begins_with import BeginsWithExpression
|
7
8
|
from vellum.workflows.expressions.between import BetweenExpression
|
@@ -26,8 +27,10 @@ if TYPE_CHECKING:
|
|
26
27
|
from vellum.workflows.expressions.is_not_undefined import IsNotUndefinedExpression
|
27
28
|
from vellum.workflows.expressions.is_null import IsNullExpression
|
28
29
|
from vellum.workflows.expressions.is_undefined import IsUndefinedExpression
|
30
|
+
from vellum.workflows.expressions.length import LengthExpression
|
29
31
|
from vellum.workflows.expressions.less_than import LessThanExpression
|
30
32
|
from vellum.workflows.expressions.less_than_or_equal_to import LessThanOrEqualToExpression
|
33
|
+
from vellum.workflows.expressions.minus import MinusExpression
|
31
34
|
from vellum.workflows.expressions.not_between import NotBetweenExpression
|
32
35
|
from vellum.workflows.expressions.not_in import NotInExpression
|
33
36
|
from vellum.workflows.expressions.or_ import OrExpression
|
@@ -127,7 +130,7 @@ class BaseDescriptor(Generic[_T]):
|
|
127
130
|
|
128
131
|
return CoalesceExpression(lhs=self, rhs=other)
|
129
132
|
|
130
|
-
def __getitem__(self, field: Union[str, int, "BaseDescriptor[str]", "BaseDescriptor[int]"]) -> "AccessorExpression":
    """
    Build an expression that accesses ``field`` on this descriptor's resolved value.

    ``field`` may be a literal key/index or a descriptor that resolves to one.
    """
    from vellum.workflows.expressions.accessor import AccessorExpression

    accessor = AccessorExpression(base=self, field=field)
    return accessor
|
@@ -376,3 +379,30 @@ class BaseDescriptor(Generic[_T]):
|
|
376
379
|
from vellum.workflows.expressions.concat import ConcatExpression
|
377
380
|
|
378
381
|
return ConcatExpression(lhs=self, rhs=other)
|
382
|
+
|
383
|
+
def length(self) -> "LengthExpression[_T]":
    """
    Build an expression that resolves to ``len()`` of this descriptor's value.
    """
    from vellum.workflows.expressions.length import LengthExpression

    length_expression = LengthExpression(expression=self)
    return length_expression
|
387
|
+
|
388
|
+
@overload
def add(self, other: "BaseDescriptor[_O]") -> "AddExpression[_T, _O]": ...

@overload
def add(self, other: _O) -> "AddExpression[_T, _O]": ...

def add(self, other: "Union[BaseDescriptor[_O], _O]") -> "AddExpression[_T, _O]":
    """
    Build an expression that resolves to ``self + other``.

    ``other`` may be a plain value or another descriptor; both operands are resolved
    against the workflow state when the resulting expression is resolved.
    """
    from vellum.workflows.expressions.add import AddExpression

    sum_expression = AddExpression(lhs=self, rhs=other)
    return sum_expression
|
398
|
+
|
399
|
+
@overload
def minus(self, other: "BaseDescriptor[_O]") -> "MinusExpression[_T, _O]": ...

@overload
def minus(self, other: _O) -> "MinusExpression[_T, _O]": ...

def minus(self, other: "Union[BaseDescriptor[_O], _O]") -> "MinusExpression[_T, _O]":
    """
    Build an expression that resolves to ``self - other``.

    ``other`` may be a plain value or another descriptor; both operands are resolved
    against the workflow state when the resulting expression is resolved.
    """
    from vellum.workflows.expressions.minus import MinusExpression

    difference_expression = MinusExpression(lhs=self, rhs=other)
    return difference_expression
|
@@ -1,7 +1,8 @@
|
|
1
1
|
from collections.abc import Mapping
|
2
2
|
import dataclasses
|
3
3
|
import inspect
|
4
|
-
from typing import Any, Dict, Optional, Sequence, Set, TypeVar, Union, cast, overload
|
4
|
+
from typing import Any, Dict, Optional, Sequence, Set, Type, TypeVar, Union, cast, overload
|
5
|
+
from typing_extensions import TypeGuard
|
5
6
|
|
6
7
|
from pydantic import BaseModel
|
7
8
|
|
@@ -115,3 +116,20 @@ def is_unresolved(value: Any) -> bool:
|
|
115
116
|
return any(is_unresolved(item) for item in value)
|
116
117
|
|
117
118
|
return False
|
119
|
+
|
120
|
+
|
121
|
+
_ResolvedType = TypeVar("_ResolvedType")
|
122
|
+
|
123
|
+
|
124
|
+
def is_resolved_instance(value: Any, type_: Type[_ResolvedType]) -> TypeGuard[_ResolvedType]:
|
125
|
+
"""
|
126
|
+
Checks if a value is an instance of a type or a descriptor that resolves to that type.
|
127
|
+
"""
|
128
|
+
|
129
|
+
if isinstance(value, type_):
|
130
|
+
return True
|
131
|
+
|
132
|
+
if isinstance(value, BaseDescriptor) and value.types:
|
133
|
+
return issubclass(value.types[0], type_)
|
134
|
+
|
135
|
+
return False
|
@@ -1,10 +1,27 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
|
+
from typing import TYPE_CHECKING, Optional
|
2
3
|
|
3
4
|
from vellum.workflows.events.workflow import WorkflowEvent
|
4
5
|
from vellum.workflows.state.base import BaseState
|
5
6
|
|
7
|
+
# To protect against circular imports
|
8
|
+
if TYPE_CHECKING:
|
9
|
+
from vellum.workflows.state.context import WorkflowContext
|
10
|
+
|
6
11
|
|
7
12
|
class BaseWorkflowEmitter(ABC):
|
13
|
+
def __init__(self):
    """Initialize the emitter without a workflow context; one is attached via ``register_context``."""
    self._context: Optional["WorkflowContext"] = None
|
15
|
+
|
16
|
+
def register_context(self, context: "WorkflowContext") -> None:
    """
    Attach the workflow context this emitter should use.

    The context is kept on ``self._context`` so emitters can reach shared
    resources such as ``vellum_client`` when emitting events.

    Args:
        context: The workflow context containing shared resources like vellum_client.
    """
    self._context = context
|
24
|
+
|
8
25
|
@abstractmethod
def emit_event(self, event: WorkflowEvent) -> None:
    """
    Handle a single workflow event.

    Concrete emitters must override this method.
    """
|
@@ -0,0 +1,138 @@
|
|
1
|
+
import logging
|
2
|
+
import time
|
3
|
+
from typing import Any, Dict, Optional
|
4
|
+
|
5
|
+
import httpx
|
6
|
+
|
7
|
+
from vellum.workflows.emitters.base import BaseWorkflowEmitter
|
8
|
+
from vellum.workflows.events.types import default_serializer
|
9
|
+
from vellum.workflows.events.workflow import WorkflowEvent
|
10
|
+
from vellum.workflows.state.base import BaseState
|
11
|
+
|
12
|
+
logger = logging.getLogger(__name__)


class VellumEmitter(BaseWorkflowEmitter):
    """
    Emitter that sends workflow events to Vellum's infrastructure for monitoring
    externally hosted SDK-powered workflows.

    Usage:
        class MyWorkflow(BaseWorkflow):
            emitters = [VellumEmitter]

    The emitter will automatically use the same Vellum client configuration
    as the workflow it's attached to.

    Delivery is best-effort: failures are logged and never raised to the workflow.
    The following are retried with exponential backoff:
        - server errors (5xx);
        - request timeouts (408);
        - rate limiting (429);
        - network errors.
    Other 4xx responses are not retried.
    """

    # 4xx status codes that signal a transient condition and are therefore retried.
    _RETRYABLE_CLIENT_STATUS_CODES = frozenset({408, 429})

    # Cap, in seconds, on the exponential backoff between attempts.
    _MAX_BACKOFF_SECONDS = 60

    def __init__(
        self,
        *,
        timeout: Optional[float] = 30.0,
        max_retries: int = 3,
    ):
        """
        Initialize the VellumEmitter.

        Args:
            timeout: Request timeout in seconds.
            max_retries: Maximum number of retry attempts for failed requests. Values
                below zero are treated as zero, so at least one attempt is always made.
        """
        super().__init__()
        self._timeout = timeout
        self._max_retries = max_retries
        self._events_endpoint = "events"  # TODO: make this configurable with the correct url

    def emit_event(self, event: WorkflowEvent) -> None:
        """
        Emit a workflow event to Vellum's infrastructure.

        Does nothing until a workflow context has been registered. Any exception raised
        while serializing or sending is logged and swallowed so emission never breaks
        the workflow.

        Args:
            event: The workflow event to emit.
        """
        if not self._context:
            return

        try:
            event_data = default_serializer(event)

            self._send_event(event_data)

        except Exception as e:
            logger.exception(f"Failed to emit event {event.name}: {e}")

    def snapshot_state(self, state: BaseState) -> None:
        """
        Send a state snapshot to Vellum's infrastructure.

        Args:
            state: The workflow state to snapshot.
        """
        pass

    @classmethod
    def _is_retryable_status(cls, status_code: int) -> bool:
        """Return whether an HTTP error status is transient and worth retrying."""
        return status_code >= 500 or status_code in cls._RETRYABLE_CLIENT_STATUS_CODES

    def _wait_before_retry(self, attempt: int, total_attempts: int, error_kind: str, detail: str) -> None:
        """Log that ``attempt`` failed and sleep with exponential backoff (1s, 2s, 4s, ... capped)."""
        wait_time = min(2**attempt, self._MAX_BACKOFF_SECONDS)
        logger.warning(
            f"{error_kind} error emitting event (attempt {attempt + 1}/{total_attempts}): "
            f"{detail}. Retrying in {wait_time}s..."
        )
        time.sleep(wait_time)

    def _send_event(self, event_data: Dict[str, Any]) -> None:
        """
        Send event data to Vellum's events endpoint with retry logic.

        Args:
            event_data: The serialized event data to send.
        """
        if not self._context:
            logger.warning("Cannot send event: No workflow context registered")
            return

        # Use the Vellum client's underlying HTTP client to make the request,
        # for proper authentication headers and configuration. None of this changes
        # between attempts, so it is computed once up front.
        client_wrapper = self._context.vellum_client._client_wrapper
        base_url = client_wrapper.get_environment().default
        url = f"{base_url}/{self._events_endpoint}"  # TODO: will be replaced with the correct url
        headers = client_wrapper.get_headers()
        total_attempts = max(self._max_retries, 0) + 1

        for attempt in range(total_attempts):
            is_last_attempt = attempt == total_attempts - 1
            try:
                response = client_wrapper.httpx_client.request(
                    method="POST",
                    path=url,
                    json=event_data,
                    headers=headers,
                    request_options={"timeout_in_seconds": self._timeout},
                )
                response.raise_for_status()

            except httpx.HTTPStatusError as e:
                status_code = e.response.status_code
                if not self._is_retryable_status(status_code):
                    # Remaining client errors (4xx) will not succeed on retry.
                    logger.exception(f"Client error emitting event: {status_code} {e.response.text}")
                    return

                error_kind = "Server" if status_code >= 500 else "Client"
                if is_last_attempt:
                    logger.exception(
                        f"{error_kind} error emitting event after {total_attempts} attempts: "
                        f"{status_code} {e.response.text}"
                    )
                    return
                self._wait_before_retry(attempt, total_attempts, error_kind, str(status_code))

            except httpx.RequestError as e:
                if is_last_attempt:
                    logger.exception(f"Network error emitting event after {total_attempts} attempts: {e}")
                    return
                self._wait_before_retry(attempt, total_attempts, "Network", str(e))

            else:
                if attempt > 0:
                    logger.info(f"Event sent successfully after {attempt + 1} attempts")
                return
|
@@ -7,10 +7,11 @@ from pydantic_core import core_schema
|
|
7
7
|
|
8
8
|
from vellum.workflows.descriptors.base import BaseDescriptor
|
9
9
|
from vellum.workflows.descriptors.exceptions import InvalidExpressionException
|
10
|
-
from vellum.workflows.descriptors.utils import resolve_value
|
10
|
+
from vellum.workflows.descriptors.utils import is_resolved_instance, resolve_value
|
11
11
|
from vellum.workflows.state.base import BaseState
|
12
12
|
|
13
13
|
LHS = TypeVar("LHS")

# A field is either a literal key/index or a descriptor that resolves to one at runtime.
AccessorField = Union[
    str,
    int,
    BaseDescriptor[str],
    BaseDescriptor[int],
]
|
14
15
|
|
15
16
|
|
16
17
|
class AccessorExpression(BaseDescriptor[Any]):
|
@@ -18,7 +19,7 @@ class AccessorExpression(BaseDescriptor[Any]):
|
|
18
19
|
self,
|
19
20
|
*,
|
20
21
|
base: BaseDescriptor[LHS],
|
21
|
-
field:
|
22
|
+
field: AccessorField,
|
22
23
|
) -> None:
|
23
24
|
super().__init__(
|
24
25
|
name=f"{base.name}.{field}",
|
@@ -28,7 +29,7 @@ class AccessorExpression(BaseDescriptor[Any]):
|
|
28
29
|
self._base = base
|
29
30
|
self._field = field
|
30
31
|
|
31
|
-
def _infer_accessor_types(self, base: BaseDescriptor[LHS], field:
|
32
|
+
def _infer_accessor_types(self, base: BaseDescriptor[LHS], field: AccessorField) -> tuple[Type, ...]:
|
32
33
|
"""
|
33
34
|
Infer the types for this accessor expression based on the base descriptor's types
|
34
35
|
and the field being accessed.
|
@@ -42,7 +43,7 @@ class AccessorExpression(BaseDescriptor[Any]):
|
|
42
43
|
origin = get_origin(base_type)
|
43
44
|
args = get_args(base_type)
|
44
45
|
|
45
|
-
if
|
46
|
+
if is_resolved_instance(field, int) and origin in (list, tuple) and args:
|
46
47
|
if origin is list:
|
47
48
|
inferred_types.append(args[0])
|
48
49
|
elif origin is tuple and len(args) == 2 and args[1] is ...:
|
@@ -52,44 +53,51 @@ class AccessorExpression(BaseDescriptor[Any]):
|
|
52
53
|
inferred_types.append(args[field])
|
53
54
|
else:
|
54
55
|
inferred_types.append(args[field])
|
55
|
-
elif
|
56
|
+
elif is_resolved_instance(field, str) and origin in (dict,) and len(args) >= 2:
|
56
57
|
inferred_types.append(args[1]) # Value type from Dict[K, V]
|
57
58
|
|
58
59
|
return tuple(set(inferred_types)) if inferred_types else ()
|
59
60
|
|
60
61
|
def resolve(self, state: "BaseState") -> Any:
|
61
62
|
base = resolve_value(self._base, state)
|
63
|
+
accessor_field = resolve_value(self._field, state)
|
62
64
|
|
63
65
|
if dataclasses.is_dataclass(base):
|
64
|
-
if isinstance(
|
66
|
+
if isinstance(accessor_field, int):
|
65
67
|
raise InvalidExpressionException("Cannot access field by index on a dataclass")
|
66
68
|
|
67
69
|
try:
|
68
|
-
return getattr(base,
|
70
|
+
return getattr(base, accessor_field)
|
69
71
|
except AttributeError:
|
70
|
-
raise InvalidExpressionException(
|
72
|
+
raise InvalidExpressionException(
|
73
|
+
f"Field '{accessor_field}' not found on dataclass {type(base).__name__}"
|
74
|
+
)
|
71
75
|
|
72
76
|
if isinstance(base, BaseModel):
|
73
|
-
if isinstance(
|
77
|
+
if isinstance(accessor_field, int):
|
74
78
|
raise InvalidExpressionException("Cannot access field by index on a BaseModel")
|
75
79
|
|
76
80
|
try:
|
77
|
-
return getattr(base,
|
81
|
+
return getattr(base, accessor_field)
|
78
82
|
except AttributeError:
|
79
|
-
raise InvalidExpressionException(
|
83
|
+
raise InvalidExpressionException(
|
84
|
+
f"Field '{accessor_field}' not found on BaseModel {type(base).__name__}"
|
85
|
+
)
|
80
86
|
|
81
87
|
if isinstance(base, Mapping):
|
82
88
|
try:
|
83
|
-
return base[
|
89
|
+
return base[accessor_field]
|
84
90
|
except KeyError:
|
85
|
-
raise InvalidExpressionException(f"Key '{
|
91
|
+
raise InvalidExpressionException(f"Key '{accessor_field}' not found in mapping")
|
86
92
|
|
87
93
|
if isinstance(base, Sequence):
|
88
94
|
try:
|
89
|
-
index = int(
|
95
|
+
index = int(accessor_field)
|
90
96
|
return base[index]
|
91
97
|
except (IndexError, ValueError):
|
92
|
-
if isinstance(
|
98
|
+
if isinstance(accessor_field, int) or (
|
99
|
+
isinstance(accessor_field, str) and accessor_field.lstrip("-").isdigit()
|
100
|
+
):
|
93
101
|
raise InvalidExpressionException(
|
94
102
|
f"Index {self._field} is out of bounds for sequence of length {len(base)}"
|
95
103
|
)
|
@@ -0,0 +1,41 @@
|
|
1
|
+
from typing import Any, Generic, Protocol, TypeVar, Union, runtime_checkable
|
2
|
+
from typing_extensions import TypeGuard
|
3
|
+
|
4
|
+
from vellum.workflows.descriptors.base import BaseDescriptor
|
5
|
+
from vellum.workflows.descriptors.exceptions import InvalidExpressionException
|
6
|
+
from vellum.workflows.descriptors.utils import resolve_value
|
7
|
+
from vellum.workflows.state.base import BaseState
|
8
|
+
|
9
|
+
|
10
|
+
@runtime_checkable
|
11
|
+
class SupportsAdd(Protocol):
|
12
|
+
def __add__(self, other: Any) -> Any: ...
|
13
|
+
|
14
|
+
|
15
|
+
def has_add(obj: Any) -> TypeGuard[SupportsAdd]:
|
16
|
+
return hasattr(obj, "__add__")
|
17
|
+
|
18
|
+
|
19
|
+
LHS = TypeVar("LHS")
RHS = TypeVar("RHS")


class AddExpression(BaseDescriptor[Any], Generic[LHS, RHS]):
    """
    Descriptor that resolves to ``lhs + rhs``.

    Either operand may be a plain value or a descriptor; descriptors are resolved
    against the workflow state before the operator is applied.
    """

    def __init__(
        self,
        *,
        lhs: Union[BaseDescriptor[LHS], LHS],
        rhs: Union[BaseDescriptor[RHS], RHS],
    ) -> None:
        # The result type depends on the runtime operands, so it is declared as ``object``.
        super().__init__(name=f"{lhs} + {rhs}", types=(object,))
        self._lhs = lhs
        self._rhs = rhs

    def resolve(self, state: "BaseState") -> Any:
        """
        Resolve both operands and add them.

        Raises:
            InvalidExpressionException: If the resolved left operand has no ``__add__``,
                or if the operands cannot be added together (e.g. ``1 + "a"``).
        """
        lhs = resolve_value(self._lhs, state)
        rhs = resolve_value(self._rhs, state)

        if not has_add(lhs):
            raise InvalidExpressionException(f"'{lhs.__class__.__name__}' must support the '+' operator")

        try:
            return lhs + rhs
        except TypeError as e:
            # Surface incompatible operand types as an expression error, like LengthExpression does.
            raise InvalidExpressionException(
                f"Cannot add '{rhs.__class__.__name__}' to '{lhs.__class__.__name__}': {e}"
            ) from e
|
@@ -0,0 +1,35 @@
|
|
1
|
+
from typing import Generic, TypeVar, Union
|
2
|
+
|
3
|
+
from vellum.workflows.constants import undefined
|
4
|
+
from vellum.workflows.descriptors.base import BaseDescriptor
|
5
|
+
from vellum.workflows.descriptors.exceptions import InvalidExpressionException
|
6
|
+
from vellum.workflows.descriptors.utils import resolve_value
|
7
|
+
from vellum.workflows.state.base import BaseState
|
8
|
+
|
9
|
+
_T = TypeVar("_T")


class LengthExpression(BaseDescriptor[int], Generic[_T]):
    """
    Descriptor that resolves to ``len()`` of another value or descriptor.
    """

    def __init__(self, *, expression: Union[BaseDescriptor[_T], _T]) -> None:
        super().__init__(name=f"length({expression})", types=(int,))
        self._expression = expression

    def resolve(self, state: "BaseState") -> int:
        """
        Resolve the wrapped value and return its length.

        Raises:
            InvalidExpressionException: If the value is ``undefined``, has no ``__len__``,
                or ``len()`` itself fails with ``TypeError``.
        """
        value = resolve_value(self._expression, state)

        if value is undefined:
            raise InvalidExpressionException("Cannot get length of undefined value")

        type_name = value.__class__.__name__
        if not hasattr(value, "__len__"):
            raise InvalidExpressionException(f"Expected an object that supports `len()`, got `{type_name}`")

        try:
            return len(value)
        except TypeError as error:
            raise InvalidExpressionException(f"Cannot get length of `{type_name}`: {str(error)}")
|
@@ -0,0 +1,41 @@
|
|
1
|
+
from typing import Any, Generic, Protocol, TypeVar, Union, runtime_checkable
|
2
|
+
from typing_extensions import TypeGuard
|
3
|
+
|
4
|
+
from vellum.workflows.descriptors.base import BaseDescriptor
|
5
|
+
from vellum.workflows.descriptors.exceptions import InvalidExpressionException
|
6
|
+
from vellum.workflows.descriptors.utils import resolve_value
|
7
|
+
from vellum.workflows.state.base import BaseState
|
8
|
+
|
9
|
+
|
10
|
+
@runtime_checkable
|
11
|
+
class SupportsMinus(Protocol):
|
12
|
+
def __sub__(self, other: Any) -> Any: ...
|
13
|
+
|
14
|
+
|
15
|
+
def has_sub(obj: Any) -> TypeGuard[SupportsMinus]:
|
16
|
+
return hasattr(obj, "__sub__")
|
17
|
+
|
18
|
+
|
19
|
+
LHS = TypeVar("LHS")
RHS = TypeVar("RHS")


class MinusExpression(BaseDescriptor[Any], Generic[LHS, RHS]):
    """
    Descriptor that resolves to ``lhs - rhs``.

    Either operand may be a plain value or a descriptor; descriptors are resolved
    against the workflow state before the operator is applied.
    """

    def __init__(
        self,
        *,
        lhs: Union[BaseDescriptor[LHS], LHS],
        rhs: Union[BaseDescriptor[RHS], RHS],
    ) -> None:
        # The result type depends on the runtime operands, so it is declared as ``object``.
        super().__init__(name=f"{lhs} - {rhs}", types=(object,))
        self._lhs = lhs
        self._rhs = rhs

    def resolve(self, state: "BaseState") -> Any:
        """
        Resolve both operands and subtract the right one from the left one.

        Raises:
            InvalidExpressionException: If the resolved left operand has no ``__sub__``,
                or if the operands cannot be subtracted (e.g. ``1 - "a"``).
        """
        lhs = resolve_value(self._lhs, state)
        rhs = resolve_value(self._rhs, state)

        if not has_sub(lhs):
            raise InvalidExpressionException(f"'{lhs.__class__.__name__}' must support the '-' operator")

        try:
            return lhs - rhs
        except TypeError as e:
            # Surface incompatible operand types as an expression error, like LengthExpression does.
            raise InvalidExpressionException(
                f"Cannot subtract '{rhs.__class__.__name__}' from '{lhs.__class__.__name__}': {e}"
            ) from e
|
@@ -0,0 +1,72 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
from vellum.workflows.descriptors.exceptions import InvalidExpressionException
|
4
|
+
from vellum.workflows.expressions.add import AddExpression
|
5
|
+
from vellum.workflows.state.base import BaseState
|
6
|
+
|
7
|
+
|
8
|
+
class TestState(BaseState):
    # Defaults are the left-hand operands the add tests below resolve against.
    number_value: int = 5
    string_value: str = "hello"
|
11
|
+
|
12
|
+
|
13
|
+
def test_add_expression_numbers():
    """
    Adding an int literal to an int-valued state descriptor yields the arithmetic sum.
    """
    # GIVEN a state whose `number_value` defaults to 5
    state = TestState()

    # WHEN `number_value + 10` is built and resolved against that state
    add_expression = TestState.number_value.add(10)
    resolved = add_expression.resolve(state)

    # THEN the sum is returned
    assert resolved == 15
|
24
|
+
|
25
|
+
|
26
|
+
def test_add_expression_strings():
    """
    Adding a string literal to a string-valued state descriptor concatenates them.
    """
    # GIVEN a state whose `string_value` defaults to "hello"
    state = TestState()

    # WHEN `string_value + " world"` is built and resolved
    add_expression = TestState.string_value.add(" world")
    resolved = add_expression.resolve(state)

    # THEN the concatenation is returned
    assert resolved == "hello world"
|
37
|
+
|
38
|
+
|
39
|
+
def test_add_expression_unsupported_type():
    """
    Resolving an addition whose left operand lacks `__add__` raises InvalidExpressionException.
    """

    class NoAddSupport:
        """Plain object that deliberately defines no `__add__`."""

    # GIVEN an expression whose left-hand side cannot be added to anything
    expression = AddExpression(lhs=NoAddSupport(), rhs=5)
    state = TestState()

    # WHEN it is resolved, THEN the operator error is surfaced
    with pytest.raises(InvalidExpressionException, match="'NoAddSupport' must support the '\\+' operator"):
        expression.resolve(state)
|
53
|
+
|
54
|
+
|
55
|
+
def test_add_expression_name():
    """
    The expression's name renders both operands around a `+`.
    """
    assert AddExpression(lhs=5, rhs=3).name == "5 + 3"
|
63
|
+
|
64
|
+
|
65
|
+
def test_add_expression_types():
    """
    The expression advertises `object` as its only type, since the sum's type is operand-dependent.
    """
    assert AddExpression(lhs=5, rhs=3).types == (object,)
|
@@ -0,0 +1,38 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
from vellum.workflows.constants import undefined
|
4
|
+
from vellum.workflows.descriptors.exceptions import InvalidExpressionException
|
5
|
+
from vellum.workflows.expressions.length import LengthExpression
|
6
|
+
from vellum.workflows.state.base import BaseState
|
7
|
+
|
8
|
+
|
9
|
+
class TestState(BaseState):
    # An 11-character default string for the length tests below.
    string_value: str = "hello world"
|
11
|
+
|
12
|
+
|
13
|
+
def test_length_expression_string():
    """
    The length of a string-valued state descriptor is its character count.
    """
    # GIVEN a state whose `string_value` defaults to "hello world"
    state = TestState()

    # WHEN the length expression is resolved
    resolved = TestState.string_value.length().resolve(state)

    # THEN the number of characters is returned
    assert resolved == 11
|
24
|
+
|
25
|
+
|
26
|
+
def test_length_expression_undefined():
    """
    Taking the length of `undefined` raises InvalidExpressionException.
    """
    # GIVEN a length expression over the `undefined` sentinel
    expression = LengthExpression(expression=undefined)
    state = TestState()

    # WHEN it is resolved, THEN an InvalidExpressionException explains why
    with pytest.raises(InvalidExpressionException) as exc_info:
        expression.resolve(state)

    assert "Cannot get length of undefined value" in str(exc_info.value)
|
@@ -0,0 +1,72 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
from vellum.workflows.descriptors.exceptions import InvalidExpressionException
|
4
|
+
from vellum.workflows.expressions.minus import MinusExpression
|
5
|
+
from vellum.workflows.state.base import BaseState
|
6
|
+
|
7
|
+
|
8
|
+
class TestState(BaseState):
    # Defaults are the left-hand operands the minus tests below resolve against.
    number_value: int = 10
    float_value: float = 15.5
|
11
|
+
|
12
|
+
|
13
|
+
def test_minus_expression_numbers():
    """
    Subtracting an int literal from an int-valued state descriptor yields the difference.
    """
    # GIVEN a state whose `number_value` defaults to 10
    state = TestState()

    # WHEN `number_value - 3` is built and resolved
    minus_expression = TestState.number_value.minus(3)
    resolved = minus_expression.resolve(state)

    # THEN the difference is returned
    assert resolved == 7
|
24
|
+
|
25
|
+
|
26
|
+
def test_minus_expression_floats():
    """
    Subtracting a float literal from a float-valued state descriptor yields the difference.
    """
    # GIVEN a state whose `float_value` defaults to 15.5
    state = TestState()

    # WHEN `float_value - 5.5` is built and resolved
    minus_expression = TestState.float_value.minus(5.5)
    resolved = minus_expression.resolve(state)

    # THEN the (exactly representable) difference is returned
    assert resolved == 10.0
|
37
|
+
|
38
|
+
|
39
|
+
def test_minus_expression_unsupported_type():
    """
    Resolving a subtraction whose left operand lacks `__sub__` raises InvalidExpressionException.
    """

    class NoSubSupport:
        """Plain object that deliberately defines no `__sub__`."""

    # GIVEN an expression whose left-hand side cannot be subtracted from
    expression = MinusExpression(lhs=NoSubSupport(), rhs=5)
    state = TestState()

    # WHEN it is resolved, THEN the operator error is surfaced
    with pytest.raises(InvalidExpressionException, match="'NoSubSupport' must support the '-' operator"):
        expression.resolve(state)
|
53
|
+
|
54
|
+
|
55
|
+
def test_minus_expression_name():
    """
    The expression's name renders both operands around a `-`.
    """
    assert MinusExpression(lhs=10, rhs=3).name == "10 - 3"
|
63
|
+
|
64
|
+
|
65
|
+
def test_minus_expression_types():
    """
    The expression advertises `object` as its only type, since the difference's type is operand-dependent.
    """
    assert MinusExpression(lhs=10, rhs=3).types == (object,)
|