vellum-ai 1.0.9__py3-none-any.whl → 1.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. vellum/client/core/client_wrapper.py +2 -2
  2. vellum/workflows/descriptors/base.py +31 -1
  3. vellum/workflows/descriptors/utils.py +19 -1
  4. vellum/workflows/emitters/__init__.py +2 -0
  5. vellum/workflows/emitters/base.py +17 -0
  6. vellum/workflows/emitters/vellum_emitter.py +138 -0
  7. vellum/workflows/expressions/accessor.py +23 -15
  8. vellum/workflows/expressions/add.py +41 -0
  9. vellum/workflows/expressions/length.py +35 -0
  10. vellum/workflows/expressions/minus.py +41 -0
  11. vellum/workflows/expressions/tests/test_add.py +72 -0
  12. vellum/workflows/expressions/tests/test_length.py +38 -0
  13. vellum/workflows/expressions/tests/test_minus.py +72 -0
  14. vellum/workflows/integrations/composio_service.py +10 -2
  15. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +1 -1
  16. vellum/workflows/nodes/displayable/inline_prompt_node/node.py +2 -2
  17. vellum/workflows/nodes/displayable/tool_calling_node/node.py +24 -20
  18. vellum/workflows/nodes/displayable/tool_calling_node/state.py +2 -0
  19. vellum/workflows/nodes/displayable/tool_calling_node/tests/test_composio_service.py +92 -0
  20. vellum/workflows/nodes/displayable/tool_calling_node/tests/test_node.py +25 -10
  21. vellum/workflows/nodes/displayable/tool_calling_node/tests/test_utils.py +7 -5
  22. vellum/workflows/nodes/displayable/tool_calling_node/utils.py +141 -86
  23. vellum/workflows/types/core.py +3 -5
  24. vellum/workflows/types/definition.py +2 -6
  25. vellum/workflows/types/tests/test_definition.py +5 -2
  26. {vellum_ai-1.0.9.dist-info → vellum_ai-1.0.11.dist-info}/METADATA +1 -1
  27. {vellum_ai-1.0.9.dist-info → vellum_ai-1.0.11.dist-info}/RECORD +34 -27
  28. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_composio_serialization.py +1 -4
  29. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_inline_workflow_serialization.py +0 -5
  30. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_serialization.py +0 -5
  31. vellum_ee/workflows/display/utils/expressions.py +12 -0
  32. {vellum_ai-1.0.9.dist-info → vellum_ai-1.0.11.dist-info}/LICENSE +0 -0
  33. {vellum_ai-1.0.9.dist-info → vellum_ai-1.0.11.dist-info}/WHEEL +0 -0
  34. {vellum_ai-1.0.9.dist-info → vellum_ai-1.0.11.dist-info}/entry_points.txt +0 -0
@@ -25,10 +25,10 @@ class BaseClientWrapper:
25
25
 
26
26
  def get_headers(self) -> typing.Dict[str, str]:
27
27
  headers: typing.Dict[str, str] = {
28
- "User-Agent": "vellum-ai/1.0.9",
28
+ "User-Agent": "vellum-ai/1.0.11",
29
29
  "X-Fern-Language": "Python",
30
30
  "X-Fern-SDK-Name": "vellum-ai",
31
- "X-Fern-SDK-Version": "1.0.9",
31
+ "X-Fern-SDK-Version": "1.0.11",
32
32
  }
33
33
  if self._api_version is not None:
34
34
  headers["X-API-Version"] = self._api_version
@@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, Tuple, Type, TypeVar,
2
2
 
3
3
  if TYPE_CHECKING:
4
4
  from vellum.workflows.expressions.accessor import AccessorExpression
5
+ from vellum.workflows.expressions.add import AddExpression
5
6
  from vellum.workflows.expressions.and_ import AndExpression
6
7
  from vellum.workflows.expressions.begins_with import BeginsWithExpression
7
8
  from vellum.workflows.expressions.between import BetweenExpression
@@ -26,8 +27,10 @@ if TYPE_CHECKING:
26
27
  from vellum.workflows.expressions.is_not_undefined import IsNotUndefinedExpression
27
28
  from vellum.workflows.expressions.is_null import IsNullExpression
28
29
  from vellum.workflows.expressions.is_undefined import IsUndefinedExpression
30
+ from vellum.workflows.expressions.length import LengthExpression
29
31
  from vellum.workflows.expressions.less_than import LessThanExpression
30
32
  from vellum.workflows.expressions.less_than_or_equal_to import LessThanOrEqualToExpression
33
+ from vellum.workflows.expressions.minus import MinusExpression
31
34
  from vellum.workflows.expressions.not_between import NotBetweenExpression
32
35
  from vellum.workflows.expressions.not_in import NotInExpression
33
36
  from vellum.workflows.expressions.or_ import OrExpression
@@ -127,7 +130,7 @@ class BaseDescriptor(Generic[_T]):
127
130
 
128
131
  return CoalesceExpression(lhs=self, rhs=other)
129
132
 
130
- def __getitem__(self, field: Union[str, int]) -> "AccessorExpression":
133
def __getitem__(self, field: Union[str, int, "BaseDescriptor[str]", "BaseDescriptor[int]"]) -> "AccessorExpression":
    """
    Build an expression that reads ``field`` from this descriptor's resolved value.

    ``field`` may be a literal key/index, or another descriptor whose resolved value is used as the key/index.
    """
    from vellum.workflows.expressions.accessor import AccessorExpression

    accessor = AccessorExpression(base=self, field=field)
    return accessor
@@ -376,3 +379,30 @@ class BaseDescriptor(Generic[_T]):
376
379
  from vellum.workflows.expressions.concat import ConcatExpression
377
380
 
378
381
  return ConcatExpression(lhs=self, rhs=other)
382
+
383
def length(self) -> "LengthExpression[_T]":
    """Return an expression that resolves to ``len()`` of this descriptor's resolved value."""
    from vellum.workflows.expressions.length import LengthExpression

    length_expression = LengthExpression(expression=self)
    return length_expression
387
+
388
@overload
def add(self, other: "BaseDescriptor[_O]") -> "AddExpression[_T, _O]": ...

@overload
def add(self, other: _O) -> "AddExpression[_T, _O]": ...

def add(self, other: "Union[BaseDescriptor[_O], _O]") -> "AddExpression[_T, _O]":
    """
    Return an expression that resolves to ``self + other``.

    ``other`` may be a plain value or another descriptor; both sides are resolved lazily.
    """
    from vellum.workflows.expressions.add import AddExpression

    sum_expression = AddExpression(lhs=self, rhs=other)
    return sum_expression
398
+
399
@overload
def minus(self, other: "BaseDescriptor[_O]") -> "MinusExpression[_T, _O]": ...

@overload
def minus(self, other: _O) -> "MinusExpression[_T, _O]": ...

def minus(self, other: "Union[BaseDescriptor[_O], _O]") -> "MinusExpression[_T, _O]":
    """
    Return an expression that resolves to ``self - other``.

    ``other`` may be a plain value or another descriptor; both sides are resolved lazily.
    """
    from vellum.workflows.expressions.minus import MinusExpression

    difference_expression = MinusExpression(lhs=self, rhs=other)
    return difference_expression
@@ -1,7 +1,8 @@
1
1
  from collections.abc import Mapping
2
2
  import dataclasses
3
3
  import inspect
4
- from typing import Any, Dict, Optional, Sequence, Set, TypeVar, Union, cast, overload
4
+ from typing import Any, Dict, Optional, Sequence, Set, Type, TypeVar, Union, cast, overload
5
+ from typing_extensions import TypeGuard
5
6
 
6
7
  from pydantic import BaseModel
7
8
 
@@ -115,3 +116,20 @@ def is_unresolved(value: Any) -> bool:
115
116
  return any(is_unresolved(item) for item in value)
116
117
 
117
118
  return False
119
+
120
+
121
+ _ResolvedType = TypeVar("_ResolvedType")
122
+
123
+
124
+ def is_resolved_instance(value: Any, type_: Type[_ResolvedType]) -> TypeGuard[_ResolvedType]:
125
+ """
126
+ Checks if a value is an instance of a type or a descriptor that resolves to that type.
127
+ """
128
+
129
+ if isinstance(value, type_):
130
+ return True
131
+
132
+ if isinstance(value, BaseDescriptor) and value.types:
133
+ return issubclass(value.types[0], type_)
134
+
135
+ return False
@@ -1,5 +1,7 @@
1
1
  from .base import BaseWorkflowEmitter
2
+ from .vellum_emitter import VellumEmitter
2
3
 
3
4
  __all__ = [
4
5
  "BaseWorkflowEmitter",
6
+ "VellumEmitter",
5
7
  ]
@@ -1,10 +1,27 @@
1
1
  from abc import ABC, abstractmethod
2
+ from typing import TYPE_CHECKING, Optional
2
3
 
3
4
  from vellum.workflows.events.workflow import WorkflowEvent
4
5
  from vellum.workflows.state.base import BaseState
5
6
 
7
+ # To protect against circular imports
8
+ if TYPE_CHECKING:
9
+ from vellum.workflows.state.context import WorkflowContext
10
+
6
11
 
7
12
  class BaseWorkflowEmitter(ABC):
13
def __init__(self):
    """Create the emitter with no workflow context; ``register_context`` stores one later."""
    self._context: Optional["WorkflowContext"] = None
15
+
16
def register_context(self, context: "WorkflowContext") -> None:
    """
    Store the workflow context on this emitter.

    Emitters can then reach shared resources through it, such as ``vellum_client``.

    Args:
        context: The workflow context to keep for subsequent emits.
    """
    self._context = context
24
+
8
25
  @abstractmethod
9
26
  def emit_event(self, event: WorkflowEvent) -> None:
10
27
  pass
@@ -0,0 +1,138 @@
1
+ import logging
2
+ import time
3
+ from typing import Any, Dict, Optional
4
+
5
+ import httpx
6
+
7
+ from vellum.workflows.emitters.base import BaseWorkflowEmitter
8
+ from vellum.workflows.events.types import default_serializer
9
+ from vellum.workflows.events.workflow import WorkflowEvent
10
+ from vellum.workflows.state.base import BaseState
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class VellumEmitter(BaseWorkflowEmitter):
    """
    Emitter that sends workflow events to Vellum's infrastructure for monitoring
    externally hosted SDK-powered workflows.

    Usage:
        class MyWorkflow(BaseWorkflow):
            emitters = [VellumEmitter]

    The emitter will automatically use the same Vellum client configuration
    as the workflow it's attached to.
    """

    def __init__(
        self,
        *,
        timeout: Optional[float] = 30.0,
        max_retries: int = 3,
    ):
        """
        Initialize the VellumEmitter.

        Args:
            timeout: Request timeout in seconds.
            max_retries: Maximum number of retry attempts for failed requests.
        """
        super().__init__()
        self._timeout = timeout
        self._max_retries = max_retries
        self._events_endpoint = "events"  # TODO: make this configurable with the correct url

    def emit_event(self, event: WorkflowEvent) -> None:
        """
        Emit a workflow event to Vellum's infrastructure.

        Emission is best-effort. Without a registered context the event is dropped. Any failure
        while serializing or sending is logged, not raised, so monitoring cannot break the workflow.

        Args:
            event: The workflow event to emit.
        """
        if not self._context:
            return

        try:
            event_data = default_serializer(event)

            self._send_event(event_data)

        except Exception as e:
            logger.exception(f"Failed to emit event {event.name}: {e}")

    def snapshot_state(self, state: BaseState) -> None:
        """
        Send a state snapshot to Vellum's infrastructure.

        Currently a no-op.

        Args:
            state: The workflow state to snapshot.
        """
        pass

    @staticmethod
    def _retry_delay(attempt: int) -> int:
        """Exponential backoff for a zero-based attempt number: 1s, 2s, 4s, ... capped at 60s."""
        return min(2**attempt, 60)

    def _send_event(self, event_data: Dict[str, Any]) -> None:
        """
        Send event data to Vellum's events endpoint with retry logic.

        The following are retried with exponential backoff, up to ``max_retries`` extra attempts:
        server errors (5xx), rate limiting (429) and network errors. Other client errors (4xx)
        are logged and not retried.

        Args:
            event_data: The serialized event data to send.
        """
        if not self._context:
            logger.warning("Cannot send event: No workflow context registered")
            return

        client = self._context.vellum_client
        total_attempts = self._max_retries + 1

        for attempt in range(total_attempts):
            has_retries_left = attempt < self._max_retries
            try:
                # Use the Vellum client's underlying HTTP client so the request carries the same
                # authentication headers and configuration as the workflow's client. The base URL is
                # passed explicitly (like Vellum's generated endpoints do) instead of being baked into
                # `path`, because the HTTP client is not guaranteed to have a default base URL.
                response = client._client_wrapper.httpx_client.request(
                    path=self._events_endpoint,  # TODO: will be replaced with the correct url
                    method="POST",
                    base_url=client._client_wrapper.get_environment().default,
                    json=event_data,
                    headers=client._client_wrapper.get_headers(),
                    request_options={"timeout_in_seconds": self._timeout},
                )

                response.raise_for_status()

            except httpx.HTTPStatusError as e:
                status_code = e.response.status_code
                if status_code < 500 and status_code != 429:
                    # Client errors (4xx other than rate limiting) are not retriable
                    logger.exception(f"Client error emitting event: {status_code} {e.response.text}")
                    return

                # Server errors and rate limiting might be transient, retry
                error_kind = "Rate limited" if status_code == 429 else "Server error"
                if not has_retries_left:
                    logger.exception(
                        f"{error_kind} emitting event after {total_attempts} attempts: "
                        f"{status_code} {e.response.text}"
                    )
                    return

                wait_time = self._retry_delay(attempt)
                logger.warning(
                    f"{error_kind} emitting event (attempt {attempt + 1}/{total_attempts}): "
                    f"{status_code}. Retrying in {wait_time}s..."
                )
                time.sleep(wait_time)

            except httpx.RequestError as e:
                if not has_retries_left:
                    logger.exception(f"Network error emitting event after {total_attempts} attempts: {e}")
                    return

                wait_time = self._retry_delay(attempt)
                logger.warning(
                    f"Network error emitting event (attempt {attempt + 1}/{total_attempts}): "
                    f"{e}. Retrying in {wait_time}s..."
                )
                time.sleep(wait_time)

            else:
                if attempt > 0:
                    logger.info(f"Event sent successfully after {attempt + 1} attempts")
                return
@@ -7,10 +7,11 @@ from pydantic_core import core_schema
7
7
 
8
8
  from vellum.workflows.descriptors.base import BaseDescriptor
9
9
  from vellum.workflows.descriptors.exceptions import InvalidExpressionException
10
- from vellum.workflows.descriptors.utils import resolve_value
10
+ from vellum.workflows.descriptors.utils import is_resolved_instance, resolve_value
11
11
  from vellum.workflows.state.base import BaseState
12
12
 
13
13
  LHS = TypeVar("LHS")
14
+ AccessorField = Union[str, int, BaseDescriptor[str], BaseDescriptor[int]]
14
15
 
15
16
 
16
17
  class AccessorExpression(BaseDescriptor[Any]):
@@ -18,7 +19,7 @@ class AccessorExpression(BaseDescriptor[Any]):
18
19
  self,
19
20
  *,
20
21
  base: BaseDescriptor[LHS],
21
- field: Union[str, int],
22
+ field: AccessorField,
22
23
  ) -> None:
23
24
  super().__init__(
24
25
  name=f"{base.name}.{field}",
@@ -28,7 +29,7 @@ class AccessorExpression(BaseDescriptor[Any]):
28
29
  self._base = base
29
30
  self._field = field
30
31
 
31
- def _infer_accessor_types(self, base: BaseDescriptor[LHS], field: Union[str, int]) -> tuple[Type, ...]:
32
+ def _infer_accessor_types(self, base: BaseDescriptor[LHS], field: AccessorField) -> tuple[Type, ...]:
32
33
  """
33
34
  Infer the types for this accessor expression based on the base descriptor's types
34
35
  and the field being accessed.
@@ -42,7 +43,7 @@ class AccessorExpression(BaseDescriptor[Any]):
42
43
  origin = get_origin(base_type)
43
44
  args = get_args(base_type)
44
45
 
45
- if isinstance(field, int) and origin in (list, tuple) and args:
46
+ if is_resolved_instance(field, int) and origin in (list, tuple) and args:
46
47
  if origin is list:
47
48
  inferred_types.append(args[0])
48
49
  elif origin is tuple and len(args) == 2 and args[1] is ...:
@@ -52,44 +53,51 @@ class AccessorExpression(BaseDescriptor[Any]):
52
53
  inferred_types.append(args[field])
53
54
  else:
54
55
  inferred_types.append(args[field])
55
- elif isinstance(field, str) and origin in (dict,) and len(args) >= 2:
56
+ elif is_resolved_instance(field, str) and origin in (dict,) and len(args) >= 2:
56
57
  inferred_types.append(args[1]) # Value type from Dict[K, V]
57
58
 
58
59
  return tuple(set(inferred_types)) if inferred_types else ()
59
60
 
60
61
  def resolve(self, state: "BaseState") -> Any:
61
62
  base = resolve_value(self._base, state)
63
+ accessor_field = resolve_value(self._field, state)
62
64
 
63
65
  if dataclasses.is_dataclass(base):
64
- if isinstance(self._field, int):
66
+ if isinstance(accessor_field, int):
65
67
  raise InvalidExpressionException("Cannot access field by index on a dataclass")
66
68
 
67
69
  try:
68
- return getattr(base, self._field)
70
+ return getattr(base, accessor_field)
69
71
  except AttributeError:
70
- raise InvalidExpressionException(f"Field '{self._field}' not found on dataclass {type(base).__name__}")
72
+ raise InvalidExpressionException(
73
+ f"Field '{accessor_field}' not found on dataclass {type(base).__name__}"
74
+ )
71
75
 
72
76
  if isinstance(base, BaseModel):
73
- if isinstance(self._field, int):
77
+ if isinstance(accessor_field, int):
74
78
  raise InvalidExpressionException("Cannot access field by index on a BaseModel")
75
79
 
76
80
  try:
77
- return getattr(base, self._field)
81
+ return getattr(base, accessor_field)
78
82
  except AttributeError:
79
- raise InvalidExpressionException(f"Field '{self._field}' not found on BaseModel {type(base).__name__}")
83
+ raise InvalidExpressionException(
84
+ f"Field '{accessor_field}' not found on BaseModel {type(base).__name__}"
85
+ )
80
86
 
81
87
  if isinstance(base, Mapping):
82
88
  try:
83
- return base[self._field]
89
+ return base[accessor_field]
84
90
  except KeyError:
85
- raise InvalidExpressionException(f"Key '{self._field}' not found in mapping")
91
+ raise InvalidExpressionException(f"Key '{accessor_field}' not found in mapping")
86
92
 
87
93
  if isinstance(base, Sequence):
88
94
  try:
89
- index = int(self._field)
95
+ index = int(accessor_field)
90
96
  return base[index]
91
97
  except (IndexError, ValueError):
92
- if isinstance(self._field, int) or (isinstance(self._field, str) and self._field.lstrip("-").isdigit()):
98
+ if isinstance(accessor_field, int) or (
99
+ isinstance(accessor_field, str) and accessor_field.lstrip("-").isdigit()
100
+ ):
93
101
  raise InvalidExpressionException(
94
102
  f"Index {self._field} is out of bounds for sequence of length {len(base)}"
95
103
  )
@@ -0,0 +1,41 @@
1
+ from typing import Any, Generic, Protocol, TypeVar, Union, runtime_checkable
2
+ from typing_extensions import TypeGuard
3
+
4
+ from vellum.workflows.descriptors.base import BaseDescriptor
5
+ from vellum.workflows.descriptors.exceptions import InvalidExpressionException
6
+ from vellum.workflows.descriptors.utils import resolve_value
7
+ from vellum.workflows.state.base import BaseState
8
+
9
+
10
+ @runtime_checkable
11
+ class SupportsAdd(Protocol):
12
+ def __add__(self, other: Any) -> Any: ...
13
+
14
+
15
+ def has_add(obj: Any) -> TypeGuard[SupportsAdd]:
16
+ return hasattr(obj, "__add__")
17
+
18
+
19
LHS = TypeVar("LHS")
RHS = TypeVar("RHS")


class AddExpression(BaseDescriptor[Any], Generic[LHS, RHS]):
    """
    Expression that resolves both operands and combines them with Python's ``+`` operator.

    Works for any operands that support addition, e.g. numbers (sum) or strings (concatenation).
    """

    def __init__(
        self,
        *,
        lhs: Union[BaseDescriptor[LHS], LHS],
        rhs: Union[BaseDescriptor[RHS], RHS],
    ) -> None:
        super().__init__(name=f"{lhs} + {rhs}", types=(object,))
        self._lhs = lhs
        self._rhs = rhs

    def resolve(self, state: "BaseState") -> Any:
        """
        Resolve both operands against ``state`` and return ``lhs + rhs``.

        Raises:
            InvalidExpressionException: if the left operand has no ``__add__``, or if the operands
                cannot be added together (e.g. ``1 + "a"``).
        """
        lhs = resolve_value(self._lhs, state)
        rhs = resolve_value(self._rhs, state)

        if not has_add(lhs):
            raise InvalidExpressionException(f"'{lhs.__class__.__name__}' must support the '+' operator")

        try:
            return lhs + rhs
        except TypeError as e:
            # Having `__add__` does not guarantee compatibility with `rhs`; report incompatible
            # operands as an expression error instead of leaking a raw TypeError.
            raise InvalidExpressionException(
                f"Cannot add `{rhs.__class__.__name__}` to `{lhs.__class__.__name__}`: {e}"
            ) from e
@@ -0,0 +1,35 @@
1
+ from typing import Generic, TypeVar, Union
2
+
3
+ from vellum.workflows.constants import undefined
4
+ from vellum.workflows.descriptors.base import BaseDescriptor
5
+ from vellum.workflows.descriptors.exceptions import InvalidExpressionException
6
+ from vellum.workflows.descriptors.utils import resolve_value
7
+ from vellum.workflows.state.base import BaseState
8
+
9
_T = TypeVar("_T")


class LengthExpression(BaseDescriptor[int], Generic[_T]):
    """Descriptor that resolves to ``len()`` of the wrapped value or descriptor."""

    def __init__(self, *, expression: Union[BaseDescriptor[_T], _T]) -> None:
        super().__init__(name=f"length({expression})", types=(int,))
        self._expression = expression

    def resolve(self, state: "BaseState") -> int:
        """
        Resolve the wrapped value against ``state`` and return its length.

        Raises:
            InvalidExpressionException: if the value is undefined, has no ``__len__``, or ``len()`` fails.
        """
        value = resolve_value(self._expression, state)

        if value is undefined:
            raise InvalidExpressionException("Cannot get length of undefined value")

        if not hasattr(value, "__len__"):
            raise InvalidExpressionException(
                f"Expected an object that supports `len()`, got `{value.__class__.__name__}`"
            )

        try:
            return len(value)
        except TypeError as e:
            raise InvalidExpressionException(f"Cannot get length of `{value.__class__.__name__}`: {str(e)}")
@@ -0,0 +1,41 @@
1
+ from typing import Any, Generic, Protocol, TypeVar, Union, runtime_checkable
2
+ from typing_extensions import TypeGuard
3
+
4
+ from vellum.workflows.descriptors.base import BaseDescriptor
5
+ from vellum.workflows.descriptors.exceptions import InvalidExpressionException
6
+ from vellum.workflows.descriptors.utils import resolve_value
7
+ from vellum.workflows.state.base import BaseState
8
+
9
+
10
+ @runtime_checkable
11
+ class SupportsMinus(Protocol):
12
+ def __sub__(self, other: Any) -> Any: ...
13
+
14
+
15
+ def has_sub(obj: Any) -> TypeGuard[SupportsMinus]:
16
+ return hasattr(obj, "__sub__")
17
+
18
+
19
LHS = TypeVar("LHS")
RHS = TypeVar("RHS")


class MinusExpression(BaseDescriptor[Any], Generic[LHS, RHS]):
    """
    Expression that resolves both operands and combines them with Python's ``-`` operator.

    Works for any operands that support subtraction, e.g. ints and floats.
    """

    def __init__(
        self,
        *,
        lhs: Union[BaseDescriptor[LHS], LHS],
        rhs: Union[BaseDescriptor[RHS], RHS],
    ) -> None:
        super().__init__(name=f"{lhs} - {rhs}", types=(object,))
        self._lhs = lhs
        self._rhs = rhs

    def resolve(self, state: "BaseState") -> Any:
        """
        Resolve both operands against ``state`` and return ``lhs - rhs``.

        Raises:
            InvalidExpressionException: if the left operand has no ``__sub__``, or if the operands
                cannot be subtracted (e.g. ``1 - "a"``).
        """
        lhs = resolve_value(self._lhs, state)
        rhs = resolve_value(self._rhs, state)

        if not has_sub(lhs):
            raise InvalidExpressionException(f"'{lhs.__class__.__name__}' must support the '-' operator")

        try:
            return lhs - rhs
        except TypeError as e:
            # Having `__sub__` does not guarantee compatibility with `rhs`; report incompatible
            # operands as an expression error instead of leaking a raw TypeError.
            raise InvalidExpressionException(
                f"Cannot subtract `{rhs.__class__.__name__}` from `{lhs.__class__.__name__}`: {e}"
            ) from e
@@ -0,0 +1,72 @@
1
+ import pytest
2
+
3
+ from vellum.workflows.descriptors.exceptions import InvalidExpressionException
4
+ from vellum.workflows.expressions.add import AddExpression
5
+ from vellum.workflows.state.base import BaseState
6
+
7
+
8
class TestState(BaseState):
    """State fixture for the AddExpression tests."""

    # Not a test class despite the `Test` prefix. Without this, pytest tries to collect it and warns
    # (PytestCollectionWarning) because the class has an `__init__` constructor.
    __test__ = False

    number_value: int = 5
    string_value: str = "hello"
11
+
12
+
13
def test_add_expression_numbers():
    """
    Tests that AddExpression correctly adds two numbers.
    """

    # GIVEN a state whose number_value defaults to 5
    state = TestState()

    # WHEN we add 10 to it and resolve the expression
    result = TestState.number_value.add(10).resolve(state)

    # THEN we get the numeric sum
    assert result == 15
24
+
25
+
26
def test_add_expression_strings():
    """
    Tests that AddExpression correctly concatenates two strings.
    """

    # GIVEN a state whose string_value defaults to "hello"
    state = TestState()

    # WHEN we add " world" to it and resolve the expression
    result = TestState.string_value.add(" world").resolve(state)

    # THEN the strings are concatenated
    assert result == "hello world"
37
+
38
+
39
def test_add_expression_unsupported_type():
    """
    Tests that AddExpression raises an exception for unsupported types.
    """

    # GIVEN an object whose class does not define `__add__`
    class NoAddSupport:
        pass

    # AND an add expression that uses it as the left-hand side
    expression = AddExpression(lhs=NoAddSupport(), rhs=5)
    state = TestState()

    # WHEN we resolve it, THEN the unsupported operator is reported
    with pytest.raises(InvalidExpressionException, match=r"'NoAddSupport' must support the '\+' operator"):
        expression.resolve(state)
53
+
54
+
55
def test_add_expression_name():
    """
    Tests that AddExpression has the correct name.
    """

    # GIVEN an add expression over two literals
    expression = AddExpression(lhs=5, rhs=3)

    # THEN its name renders the operation in infix form
    assert expression.name == "5 + 3"
63
+
64
+
65
def test_add_expression_types():
    """
    Tests that AddExpression has the correct types.
    """

    # GIVEN an add expression over two literals
    expression = AddExpression(lhs=5, rhs=3)

    # THEN its declared types are the generic `object`
    assert expression.types == (object,)
@@ -0,0 +1,38 @@
1
+ import pytest
2
+
3
+ from vellum.workflows.constants import undefined
4
+ from vellum.workflows.descriptors.exceptions import InvalidExpressionException
5
+ from vellum.workflows.expressions.length import LengthExpression
6
+ from vellum.workflows.state.base import BaseState
7
+
8
+
9
class TestState(BaseState):
    """State fixture for the LengthExpression tests."""

    # Not a test class despite the `Test` prefix. Without this, pytest tries to collect it and warns
    # (PytestCollectionWarning) because the class has an `__init__` constructor.
    __test__ = False

    string_value: str = "hello world"
11
+
12
+
13
def test_length_expression_string():
    """
    Tests that LengthExpression correctly returns the length of a string.
    """

    # GIVEN a state whose string_value is "hello world"
    state = TestState()

    # WHEN we resolve the length of that field
    length_expression = TestState.string_value.length()
    result = length_expression.resolve(state)

    # THEN we get the number of characters
    assert result == 11
24
+
25
+
26
def test_length_expression_undefined():
    """
    Tests that LengthExpression raises an exception for undefined values.
    """

    # GIVEN a length expression over an undefined value
    expression = LengthExpression(expression=undefined)
    state = TestState()

    # WHEN we resolve it, THEN we should get an InvalidExpressionException
    with pytest.raises(InvalidExpressionException, match="Cannot get length of undefined value"):
        expression.resolve(state)
@@ -0,0 +1,72 @@
1
+ import pytest
2
+
3
+ from vellum.workflows.descriptors.exceptions import InvalidExpressionException
4
+ from vellum.workflows.expressions.minus import MinusExpression
5
+ from vellum.workflows.state.base import BaseState
6
+
7
+
8
class TestState(BaseState):
    """State fixture for the MinusExpression tests."""

    # Not a test class despite the `Test` prefix. Without this, pytest tries to collect it and warns
    # (PytestCollectionWarning) because the class has an `__init__` constructor.
    __test__ = False

    number_value: int = 10
    float_value: float = 15.5
11
+
12
+
13
def test_minus_expression_numbers():
    """
    Tests that MinusExpression correctly subtracts two numbers.
    """

    # GIVEN a state whose number_value defaults to 10
    state = TestState()

    # WHEN we subtract 3 and resolve the expression
    result = TestState.number_value.minus(3).resolve(state)

    # THEN we get the difference
    assert result == 7
24
+
25
+
26
def test_minus_expression_floats():
    """
    Tests that MinusExpression correctly subtracts two floats.
    """

    # GIVEN a state whose float_value defaults to 15.5
    state = TestState()

    # WHEN we subtract 5.5 and resolve the expression
    result = TestState.float_value.minus(5.5).resolve(state)

    # THEN we get the float difference
    assert result == 10.0
37
+
38
+
39
def test_minus_expression_unsupported_type():
    """
    Tests that MinusExpression raises an exception for unsupported types.
    """

    # GIVEN an object whose class does not define `__sub__`
    class NoSubSupport:
        pass

    # AND a minus expression that uses it as the left-hand side
    expression = MinusExpression(lhs=NoSubSupport(), rhs=5)
    state = TestState()

    # WHEN we resolve it, THEN the unsupported operator is reported
    with pytest.raises(InvalidExpressionException, match=r"'NoSubSupport' must support the '-' operator"):
        expression.resolve(state)
53
+
54
+
55
def test_minus_expression_name():
    """
    Tests that MinusExpression has the correct name.
    """

    # GIVEN a minus expression over two literals
    expression = MinusExpression(lhs=10, rhs=3)

    # THEN its name renders the operation in infix form
    assert expression.name == "10 - 3"
63
+
64
+
65
def test_minus_expression_types():
    """
    Tests that MinusExpression has the correct types.
    """

    # GIVEN a minus expression over two literals
    expression = MinusExpression(lhs=10, rhs=3)

    # THEN its declared types are the generic `object`
    assert expression.types == (object,)