vellum-ai 1.2.2__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. vellum/__init__.py +40 -0
  2. vellum/client/core/client_wrapper.py +2 -2
  3. vellum/client/core/pydantic_utilities.py +3 -2
  4. vellum/client/reference.md +16 -0
  5. vellum/client/resources/workflow_executions/client.py +28 -4
  6. vellum/client/resources/workflow_executions/raw_client.py +32 -2
  7. vellum/client/types/__init__.py +40 -0
  8. vellum/client/types/audio_input_request.py +30 -0
  9. vellum/client/types/delimiter_chunker_config.py +20 -0
  10. vellum/client/types/delimiter_chunker_config_request.py +20 -0
  11. vellum/client/types/delimiter_chunking.py +21 -0
  12. vellum/client/types/delimiter_chunking_request.py +21 -0
  13. vellum/client/types/document_index_chunking.py +4 -1
  14. vellum/client/types/document_index_chunking_request.py +2 -1
  15. vellum/client/types/document_input_request.py +30 -0
  16. vellum/client/types/execution_audio_vellum_value.py +31 -0
  17. vellum/client/types/execution_document_vellum_value.py +31 -0
  18. vellum/client/types/execution_image_vellum_value.py +31 -0
  19. vellum/client/types/execution_vellum_value.py +8 -0
  20. vellum/client/types/execution_video_vellum_value.py +31 -0
  21. vellum/client/types/image_input_request.py +30 -0
  22. vellum/client/types/logical_operator.py +1 -0
  23. vellum/client/types/node_input_compiled_audio_value.py +23 -0
  24. vellum/client/types/node_input_compiled_document_value.py +23 -0
  25. vellum/client/types/node_input_compiled_image_value.py +23 -0
  26. vellum/client/types/node_input_compiled_video_value.py +23 -0
  27. vellum/client/types/node_input_variable_compiled_value.py +8 -0
  28. vellum/client/types/prompt_deployment_input_request.py +13 -1
  29. vellum/client/types/prompt_request_audio_input.py +26 -0
  30. vellum/client/types/prompt_request_document_input.py +26 -0
  31. vellum/client/types/prompt_request_image_input.py +26 -0
  32. vellum/client/types/prompt_request_input.py +13 -1
  33. vellum/client/types/prompt_request_video_input.py +26 -0
  34. vellum/client/types/video_input_request.py +30 -0
  35. vellum/types/audio_input_request.py +3 -0
  36. vellum/types/delimiter_chunker_config.py +3 -0
  37. vellum/types/delimiter_chunker_config_request.py +3 -0
  38. vellum/types/delimiter_chunking.py +3 -0
  39. vellum/types/delimiter_chunking_request.py +3 -0
  40. vellum/types/document_input_request.py +3 -0
  41. vellum/types/execution_audio_vellum_value.py +3 -0
  42. vellum/types/execution_document_vellum_value.py +3 -0
  43. vellum/types/execution_image_vellum_value.py +3 -0
  44. vellum/types/execution_video_vellum_value.py +3 -0
  45. vellum/types/image_input_request.py +3 -0
  46. vellum/types/node_input_compiled_audio_value.py +3 -0
  47. vellum/types/node_input_compiled_document_value.py +3 -0
  48. vellum/types/node_input_compiled_image_value.py +3 -0
  49. vellum/types/node_input_compiled_video_value.py +3 -0
  50. vellum/types/prompt_request_audio_input.py +3 -0
  51. vellum/types/prompt_request_document_input.py +3 -0
  52. vellum/types/prompt_request_image_input.py +3 -0
  53. vellum/types/prompt_request_video_input.py +3 -0
  54. vellum/types/video_input_request.py +3 -0
  55. vellum/workflows/context.py +27 -9
  56. vellum/workflows/events/context.py +53 -78
  57. vellum/workflows/events/node.py +5 -5
  58. vellum/workflows/events/relational_threads.py +41 -0
  59. vellum/workflows/events/tests/test_basic_workflow.py +50 -0
  60. vellum/workflows/events/workflow.py +12 -1
  61. vellum/workflows/expressions/contains.py +7 -0
  62. vellum/workflows/expressions/tests/test_contains.py +175 -0
  63. vellum/workflows/graph/graph.py +52 -8
  64. vellum/workflows/graph/tests/test_graph.py +17 -0
  65. vellum/workflows/integrations/mcp_service.py +35 -5
  66. vellum/workflows/integrations/tests/test_mcp_service.py +81 -0
  67. vellum/workflows/nodes/core/error_node/node.py +4 -0
  68. vellum/workflows/nodes/core/map_node/node.py +7 -0
  69. vellum/workflows/nodes/core/map_node/tests/test_node.py +19 -0
  70. vellum/workflows/nodes/displayable/final_output_node/node.py +4 -0
  71. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +1 -1
  72. vellum/workflows/ports/node_ports.py +3 -0
  73. vellum/workflows/ports/port.py +7 -0
  74. vellum/workflows/state/context.py +35 -4
  75. vellum/workflows/utils/uuids.py +15 -0
  76. {vellum_ai-1.2.2.dist-info → vellum_ai-1.2.3.dist-info}/METADATA +1 -1
  77. {vellum_ai-1.2.2.dist-info → vellum_ai-1.2.3.dist-info}/RECORD +85 -39
  78. vellum_ee/workflows/display/nodes/vellum/error_node.py +1 -5
  79. vellum_ee/workflows/display/nodes/vellum/final_output_node.py +1 -5
  80. vellum_ee/workflows/display/utils/events.py +24 -0
  81. vellum_ee/workflows/display/utils/tests/test_events.py +69 -0
  82. vellum_ee/workflows/tests/test_server.py +95 -0
  83. {vellum_ai-1.2.2.dist-info → vellum_ai-1.2.3.dist-info}/LICENSE +0 -0
  84. {vellum_ai-1.2.2.dist-info → vellum_ai-1.2.3.dist-info}/WHEEL +0 -0
  85. {vellum_ai-1.2.2.dist-info → vellum_ai-1.2.3.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.node_input_compiled_video_value import *
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.prompt_request_audio_input import *
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.prompt_request_document_input import *
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.prompt_request_image_input import *
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.prompt_request_video_input import *
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.video_input_request import *
@@ -1,10 +1,11 @@
1
1
  from contextlib import contextmanager
2
2
  from dataclasses import field
3
3
  import threading
4
- from uuid import UUID
4
+ from uuid import UUID, uuid4
5
5
  from typing import Iterator, Optional, cast
6
6
 
7
7
  from vellum.client.core import UniversalBaseModel
8
+ from vellum.workflows.events.context import MonitoringContextStore
8
9
  from vellum.workflows.events.types import ParentContext
9
10
 
10
11
 
@@ -17,16 +18,37 @@ _CONTEXT_KEY = "_execution_context"
17
18
 
18
19
  local = threading.local()
19
20
 
21
+ monitoring_context_store = MonitoringContextStore()
22
+
20
23
 
21
24
  def get_execution_context() -> ExecutionContext:
22
- """Retrieve the current execution context."""
23
- return getattr(local, _CONTEXT_KEY, ExecutionContext())
25
+ """Get the current monitoring execution context, with intelligent fallback."""
26
+ context = getattr(local, _CONTEXT_KEY, ExecutionContext())
27
+ if context.parent_context:
28
+ return context
29
+
30
+ # If no thread-local context, try to restore from global store using current trace_id
31
+ context = monitoring_context_store.retrieve_context()
32
+ if context and context.parent_context:
33
+ set_execution_context(context)
34
+ return context
35
+ return ExecutionContext()
24
36
 
25
37
 
26
38
  def set_execution_context(context: ExecutionContext) -> None:
27
- """Set the current execution context."""
39
+ """Set the current monitoring execution context and persist it for cross-boundary access."""
28
40
  setattr(local, _CONTEXT_KEY, context)
29
41
 
42
+ # Always store in global store for cross-thread access
43
+ monitoring_context_store.store_context(context)
44
+
45
+
46
+ def clear_execution_context() -> None:
47
+ """Clear the current monitoring execution context."""
48
+ if hasattr(local, _CONTEXT_KEY):
49
+ delattr(local, _CONTEXT_KEY)
50
+ monitoring_context_store.clear_context()
51
+
30
52
 
31
53
  def get_parent_context() -> ParentContext:
32
54
  return cast(ParentContext, get_execution_context().parent_context)
@@ -38,11 +60,7 @@ def execution_context(
38
60
  ) -> Iterator[None]:
39
61
  """Context manager for handling execution context."""
40
62
  prev_context = get_execution_context()
41
- set_trace_id = (
42
- prev_context.trace_id
43
- if int(prev_context.trace_id)
44
- else trace_id or UUID("00000000-0000-0000-0000-000000000000")
45
- )
63
+ set_trace_id = prev_context.trace_id if int(prev_context.trace_id) else trace_id or uuid4()
46
64
  set_parent_context = parent_context or prev_context.parent_context
47
65
  set_context = ExecutionContext(parent_context=set_parent_context, trace_id=set_trace_id)
48
66
  try:
@@ -2,110 +2,85 @@
2
2
 
3
3
  import threading
4
4
  from uuid import UUID
5
- from typing import Dict, Optional
5
+ from typing import TYPE_CHECKING, Dict, Optional
6
6
 
7
- from vellum.workflows.context import ExecutionContext
7
+ if TYPE_CHECKING:
8
+ from vellum.workflows.context import ExecutionContext
8
9
 
9
10
  DEFAULT_TRACE_ID = UUID("00000000-0000-0000-0000-000000000000")
10
11
 
11
- # Thread-local storage for monitoring execution context
12
- _monitoring_execution_context: threading.local = threading.local()
13
- # Thread-local storage for current span_id
14
- _current_span_id: threading.local = threading.local()
15
12
 
16
-
17
- class _MonitoringContextStore:
13
+ class MonitoringContextStore:
18
14
  """
19
15
  thread-safe storage for monitoring contexts.
20
16
  handles context persistence and retrieval across threads.
21
- relies on the execution context manager for manual retrieval
22
17
  """
23
18
 
24
19
  def __init__(self):
25
20
  self._lock = threading.Lock()
26
- self._contexts: Dict[str, ExecutionContext] = {}
27
- self._thread_contexts: Dict[int, ExecutionContext] = {}
28
- self._current_trace_id: Optional[UUID] = None
29
-
30
- def set_current_trace_id(self, trace_id: UUID) -> None:
31
- """Set the current active trace_id that should be used by all threads."""
32
- if trace_id != DEFAULT_TRACE_ID:
33
- with self._lock:
34
- self._current_trace_id = trace_id
21
+ self._contexts: Dict[str, "ExecutionContext"] = {}
35
22
 
36
23
  def get_current_trace_id(self) -> Optional[UUID]:
37
24
  """Get the current active trace_id that should be used by all threads."""
38
25
  with self._lock:
39
- return self._current_trace_id
40
-
41
- def set_current_span_id(self, span_id: UUID) -> None:
42
- """Set the current active span_id for this thread."""
43
- _current_span_id.span_id = span_id
44
-
45
- def get_current_span_id(self) -> Optional[UUID]:
46
- """Get the current active span_id for this thread."""
47
- return getattr(_current_span_id, "span_id", None)
48
-
49
- def store_context(self, context: Optional[ExecutionContext]) -> None:
50
- """Store monitoring parent context using multiple keys for reliable retrieval."""
51
- if not context or context.parent_context is None:
52
- return
26
+ current_context = self.retrieve_context()
27
+ return current_context.trace_id if current_context else None
53
28
 
29
+ def store_context(self, context: "ExecutionContext") -> None:
30
+ """Store monitoring parent context using trace:span:thread keys."""
54
31
  thread_id = threading.get_ident()
55
- trace_id = self.get_current_trace_id()
56
- if context.trace_id != DEFAULT_TRACE_ID and trace_id is None:
57
- self.set_current_trace_id(context.trace_id)
32
+ current_thread = threading.current_thread()
33
+ trace_id = context.trace_id
58
34
 
59
35
  with self._lock:
60
- # Use trace:span:thread for unique context storage
61
- trace_span_thread_key = (
62
- f"trace:{str(trace_id)}:span:{str(context.parent_context.span_id)}:thread:{thread_id}"
63
- )
64
- self._contexts[trace_span_thread_key] = context
65
-
66
- def retrieve_context(self, trace_id: UUID, span_id: Optional[UUID] = None) -> Optional[ExecutionContext]:
67
- """Retrieve monitoring parent context with multiple fallback strategies."""
68
- thread_id = threading.get_ident()
69
- with self._lock:
70
- if not span_id:
71
- span_id = getattr(_current_span_id, "span_id", None)
72
- if not span_id:
73
- return None
74
-
75
- span_key = f"trace:{str(trace_id)}:span:{str(span_id)}:thread:{thread_id}"
76
- if span_key in self._contexts:
77
- result = self._contexts[span_key]
78
- return result
36
+ # Get span_id from RelationalThread first
37
+ span_id = None
79
38
 
80
- return None
39
+ if context.parent_context and hasattr(context.parent_context, "span_id"):
40
+ span_id = context.parent_context.span_id
81
41
 
42
+ if not span_id and hasattr(current_thread, "get_parent_span_id"):
43
+ span_id = current_thread.get_parent_span_id()
82
44
 
83
- # Global instance for cross-boundary context persistence
84
- _monitoring_context_store = _MonitoringContextStore()
45
+ # Always use trace:span:thread format - require span_id
46
+ if span_id:
47
+ context_key = f"trace:{str(trace_id)}:span:{str(span_id)}:thread:{thread_id}"
48
+ self._contexts[context_key] = context
85
49
 
50
+ def retrieve_context(self) -> Optional["ExecutionContext"]:
51
+ """Retrieve monitoring parent context using trace:span:thread keys."""
52
+ current_thread = threading.current_thread()
53
+ current_thread_id = current_thread.ident
86
54
 
87
- def get_monitoring_execution_context() -> ExecutionContext:
88
- """Get the current monitoring execution context, with intelligent fallback."""
89
- if hasattr(_monitoring_execution_context, "context"):
90
- context = _monitoring_execution_context.context
91
- if context.trace_id != DEFAULT_TRACE_ID and context.parent_context:
92
- return context
55
+ # Get trace_id and span_id directly from the thread if it's a RelationalThread
56
+ trace_id = None
57
+ span_id = None
58
+ if hasattr(current_thread, "get_trace_id"):
59
+ trace_id = current_thread.get_trace_id()
60
+ if hasattr(current_thread, "get_parent_span_id"):
61
+ span_id = current_thread.get_parent_span_id()
93
62
 
94
- # If no thread-local context, try to restore from global store using current trace_id
95
- trace_id = _monitoring_context_store.get_current_trace_id()
96
- span_id = _current_span_id.span_id if hasattr(_current_span_id, "span_id") else None
97
- if trace_id:
98
- if trace_id != DEFAULT_TRACE_ID:
99
- context = _monitoring_context_store.retrieve_context(trace_id, span_id)
100
- if context:
101
- _monitoring_execution_context.context = context
102
- return context
103
- return ExecutionContext()
63
+ # Require both trace_id and span_id - no fallback searching
64
+ if not trace_id or not span_id:
65
+ return None
104
66
 
67
+ with self._lock:
68
+ # Try current thread
69
+ current_key = f"trace:{str(trace_id)}:span:{str(span_id)}:thread:{current_thread_id}"
70
+ if current_key in self._contexts:
71
+ return self._contexts[current_key]
72
+
73
+ # Try parent thread with same trace and span
74
+ if hasattr(current_thread, "get_parent_thread"):
75
+ parent_thread_id = current_thread.get_parent_thread()
76
+ if parent_thread_id:
77
+ parent_key = f"trace:{str(trace_id)}:span:{str(span_id)}:thread:{parent_thread_id}"
78
+ if parent_key in self._contexts:
79
+ return self._contexts[parent_key]
105
80
 
106
- def set_monitoring_execution_context(context: ExecutionContext) -> None:
107
- """Set the current monitoring execution context and persist it for cross-boundary access."""
108
- _monitoring_execution_context.context = context
81
+ return None
109
82
 
110
- if context.trace_id and context.parent_context:
111
- _monitoring_context_store.store_context(context)
83
+ def clear_context(self):
84
+ """Clear all stored contexts."""
85
+ with self._lock:
86
+ self._contexts.clear()
@@ -1,6 +1,6 @@
1
1
  from typing import TYPE_CHECKING, Any, Dict, Generic, List, Literal, Optional, Set, Type, Union
2
2
 
3
- from pydantic import field_serializer, model_serializer
3
+ from pydantic import SerializationInfo, field_serializer, model_serializer
4
4
 
5
5
  from vellum.client.core.pydantic_utilities import UniversalBaseModel
6
6
  from vellum.workflows.errors import WorkflowError
@@ -82,8 +82,8 @@ class NodeExecutionStreamingEvent(_BaseNodeEvent):
82
82
  return self.body.invoked_ports
83
83
 
84
84
  @model_serializer(mode="plain", when_used="json")
85
- def serialize_model(self) -> Any:
86
- serialized = super().serialize_model()
85
+ def serialize_model(self, info: SerializationInfo) -> Any:
86
+ serialized = super().serialize_model(info)
87
87
  if (
88
88
  "body" in serialized
89
89
  and isinstance(serialized["body"], dict)
@@ -127,8 +127,8 @@ class NodeExecutionFulfilledEvent(_BaseNodeEvent, Generic[OutputsType]):
127
127
  return self.body.mocked
128
128
 
129
129
  @model_serializer(mode="plain", when_used="json")
130
- def serialize_model(self) -> Any:
131
- serialized = super().serialize_model()
130
+ def serialize_model(self, info: SerializationInfo) -> Any:
131
+ serialized = super().serialize_model(info)
132
132
  if (
133
133
  "body" in serialized
134
134
  and isinstance(serialized["body"], dict)
@@ -0,0 +1,41 @@
1
+ import threading
2
+ from uuid import UUID
3
+ from typing import TYPE_CHECKING, Optional
4
+
5
+ if TYPE_CHECKING:
6
+ from vellum.workflows.context import ExecutionContext
7
+
8
+
9
+ class RelationalThread(threading.Thread):
10
+ _parent_thread: Optional[int] = None
11
+ _trace_id: Optional[UUID] = None
12
+ _parent_span_id: Optional[UUID] = None
13
+
14
+ def __init__(self, *args, execution_context: Optional["ExecutionContext"] = None, **kwargs):
15
+ self._collect_parent_context(execution_context)
16
+ threading.Thread.__init__(self, *args, **kwargs)
17
+
18
+ def _collect_parent_context(self, execution_context: Optional["ExecutionContext"] = None) -> None:
19
+ """Collect parent thread ID, trace ID, and parent span ID from passed execution context."""
20
+ self._parent_thread = threading.get_ident()
21
+
22
+ # Only use explicitly passed execution context
23
+ if execution_context:
24
+ self._trace_id = execution_context.trace_id
25
+ self._parent_span_id = (
26
+ execution_context.parent_context.span_id
27
+ if execution_context.parent_context and hasattr(execution_context.parent_context, "span_id")
28
+ else None
29
+ )
30
+ else:
31
+ self._trace_id = None
32
+ self._parent_span_id = None
33
+
34
+ def get_parent_thread(self) -> Optional[int]:
35
+ return self._parent_thread
36
+
37
+ def get_trace_id(self) -> Optional[UUID]:
38
+ return self._trace_id
39
+
40
+ def get_parent_span_id(self) -> Optional[UUID]:
41
+ return self._parent_span_id
@@ -0,0 +1,50 @@
1
+ import logging
2
+ from uuid import uuid4
3
+
4
+ from vellum.workflows import BaseWorkflow
5
+ from vellum.workflows.context import execution_context
6
+ from vellum.workflows.nodes.bases import BaseNode
7
+ from vellum.workflows.workflows.event_filters import all_workflow_event_filter
8
+
9
+
10
+ class StartNode(BaseNode):
11
+ pass
12
+
13
+
14
+ class TrivialWorkflow(BaseWorkflow):
15
+ graph = StartNode
16
+
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ def test_basic_workflow_monitoring_context_flow():
22
+ """Test that monitoring creates the correct workflow→node context hierarchy using streamed events.
23
+ What's missing:
24
+ - span_id from previous event mapping to parent context
25
+ """
26
+
27
+ workflow = TrivialWorkflow()
28
+
29
+ with execution_context(trace_id=uuid4()):
30
+ events = list(workflow.stream(event_filter=all_workflow_event_filter))
31
+
32
+ # Verify workflow succeeded
33
+ assert len(events) >= 2
34
+ assert events[0].name == "workflow.execution.initiated"
35
+ assert events[-1].name == "workflow.execution.fulfilled"
36
+
37
+ # Collect all events with parent context
38
+ events_with_parent = [event for event in events if event.parent is not None]
39
+
40
+ assert len(events_with_parent) > 0, "Expected at least some events with parent context"
41
+
42
+ # Filter for node events
43
+ node_events = [event for event in events if event.name.startswith("node.")]
44
+ assert len(node_events) > 0, "Expected at least some node events"
45
+
46
+ # Verify each node event has the workflow as its parent context
47
+ for event in node_events:
48
+ assert event.parent is not None, "Node event should have parent context"
49
+ assert event.parent.type == "WORKFLOW", "Node event parent should be workflow"
50
+ assert event.parent.workflow_definition.name == "TrivialWorkflow", "Parent workflow name mismatch"
@@ -2,7 +2,7 @@ from uuid import UUID
2
2
  from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Literal, Optional, Type, Union
3
3
  from typing_extensions import TypeGuard
4
4
 
5
- from pydantic import field_serializer
5
+ from pydantic import SerializationInfo, field_serializer
6
6
 
7
7
  from vellum.client.core.pydantic_utilities import UniversalBaseModel
8
8
  from vellum.workflows.errors import WorkflowError
@@ -101,6 +101,17 @@ class WorkflowExecutionInitiatedEvent(_BaseWorkflowEvent, Generic[InputsType, St
101
101
  def initial_state(self) -> Optional[StateType]:
102
102
  return self.body.initial_state
103
103
 
104
+ @field_serializer("body")
105
+ def serialize_body(
106
+ self, body: WorkflowExecutionInitiatedBody[InputsType, StateType], info: SerializationInfo
107
+ ) -> WorkflowExecutionInitiatedBody[InputsType, StateType]:
108
+ context = info.context if info and hasattr(info, "context") else {}
109
+ if context and "event_enricher" in context and callable(context["event_enricher"]):
110
+ event = context["event_enricher"](self)
111
+ return event.body
112
+ else:
113
+ return body
114
+
104
115
 
105
116
  class WorkflowExecutionStreamingBody(_BaseWorkflowExecutionBody):
106
117
  output: BaseOutput
@@ -35,4 +35,11 @@ class ContainsExpression(BaseDescriptor[bool], Generic[LHS, RHS]):
35
35
  )
36
36
 
37
37
  rhs = resolve_value(self._rhs, state)
38
+
39
+ if isinstance(rhs, dict):
40
+ raise InvalidExpressionException(
41
+ "Cannot use dict as right-hand side of contains operation. "
42
+ "Use dict keys/values or convert to strings for comparison."
43
+ )
44
+
38
45
  return rhs in lhs
@@ -0,0 +1,175 @@
1
+ import pytest
2
+
3
+ from vellum.workflows.constants import undefined
4
+ from vellum.workflows.descriptors.exceptions import InvalidExpressionException
5
+ from vellum.workflows.expressions.contains import ContainsExpression
6
+ from vellum.workflows.references.constant import ConstantValueReference
7
+ from vellum.workflows.state.base import BaseState
8
+
9
+
10
+ class TestState(BaseState):
11
+ dict_value: dict = {"key": "value"}
12
+ list_value: list = [1, 2, 3]
13
+ string_value: str = "hello world"
14
+
15
+
16
+ def test_dict_contains_dict_raises_error():
17
+ """
18
+ Tests that ContainsExpression raises clear error for dict-contains-dict scenarios.
19
+ """
20
+ state = TestState()
21
+ lhs_dict = {"foo": "bar"}
22
+ rhs_dict = {"foo": "bar"}
23
+
24
+ expression = ContainsExpression(lhs=lhs_dict, rhs=rhs_dict)
25
+
26
+ with pytest.raises(InvalidExpressionException, match="Cannot use dict as right-hand side"):
27
+ expression.resolve(state)
28
+
29
+
30
+ def test_dict_contains_different_dict_raises_error():
31
+ """
32
+ Tests that ContainsExpression raises clear error for different dict-contains-dict scenarios.
33
+ """
34
+ state = TestState()
35
+ lhs_dict = {"foo": "bar"}
36
+ rhs_dict = {"hello": "world"}
37
+
38
+ expression = ContainsExpression(lhs=lhs_dict, rhs=rhs_dict)
39
+
40
+ with pytest.raises(InvalidExpressionException, match="Cannot use dict as right-hand side"):
41
+ expression.resolve(state)
42
+
43
+
44
+ def test_string_contains_dict_raises_error():
45
+ """
46
+ Tests that ContainsExpression raises clear error for string-contains-dict scenarios.
47
+ """
48
+ state = TestState()
49
+ lhs_string = 'Response: {"status": "success"} was returned'
50
+ rhs_dict = {"status": "success"}
51
+
52
+ expression = ContainsExpression(lhs=lhs_string, rhs=rhs_dict)
53
+
54
+ with pytest.raises(InvalidExpressionException, match="Cannot use dict as right-hand side"):
55
+ expression.resolve(state)
56
+
57
+
58
+ def test_nested_dict_contains_dict_raises_error():
59
+ """
60
+ Tests that ContainsExpression raises clear error for nested dict scenarios.
61
+ """
62
+ state = TestState()
63
+ lhs_dict = {"user": {"name": "john", "age": 30}}
64
+ rhs_dict = {"age": 30, "name": "john"}
65
+
66
+ expression = ContainsExpression(lhs=lhs_dict, rhs=rhs_dict)
67
+
68
+ with pytest.raises(InvalidExpressionException, match="Cannot use dict as right-hand side"):
69
+ expression.resolve(state)
70
+
71
+
72
+ def test_list_contains_string():
73
+ """
74
+ Tests that ContainsExpression preserves original list functionality.
75
+ """
76
+ state = TestState()
77
+
78
+ expression = TestState.list_value.contains(2)
79
+ result = expression.resolve(state)
80
+
81
+ assert result is True
82
+
83
+
84
+ def test_string_contains_substring():
85
+ """
86
+ Tests that ContainsExpression preserves original string functionality.
87
+ """
88
+ state = TestState()
89
+
90
+ expression = TestState.string_value.contains("world")
91
+ result = expression.resolve(state)
92
+
93
+ assert result is True
94
+
95
+
96
+ def test_set_contains_item():
97
+ """
98
+ Tests that ContainsExpression works with sets.
99
+ """
100
+ state = TestState()
101
+ lhs_set = {1, 2, 3}
102
+ rhs_item = 2
103
+
104
+ expression = ContainsExpression(lhs=lhs_set, rhs=rhs_item)
105
+ result = expression.resolve(state)
106
+
107
+ assert result is True
108
+
109
+
110
+ def test_tuple_contains_item():
111
+ """
112
+ Tests that ContainsExpression works with tuples.
113
+ """
114
+ state = TestState()
115
+ lhs_tuple = (1, 2, 3)
116
+ rhs_item = 2
117
+
118
+ expression = ContainsExpression(lhs=lhs_tuple, rhs=rhs_item)
119
+ result = expression.resolve(state)
120
+
121
+ assert result is True
122
+
123
+
124
+ def test_invalid_lhs_type():
125
+ """
126
+ Tests that ContainsExpression raises exception for invalid LHS types.
127
+ """
128
+
129
+ class NoContainsSupport:
130
+ pass
131
+
132
+ state = TestState()
133
+ no_contains_obj = NoContainsSupport()
134
+ expression = ContainsExpression(lhs=no_contains_obj, rhs="test")
135
+
136
+ with pytest.raises(
137
+ InvalidExpressionException, match="Expected a LHS that supported `contains`, got `NoContainsSupport`"
138
+ ):
139
+ expression.resolve(state)
140
+
141
+
142
+ def test_undefined_lhs_returns_false():
143
+ """
144
+ Tests that ContainsExpression returns False for undefined LHS.
145
+ """
146
+ state = TestState()
147
+ expression = ContainsExpression(lhs=undefined, rhs="test")
148
+
149
+ result = expression.resolve(state)
150
+
151
+ assert result is False
152
+
153
+
154
+ def test_contains_with_constant_value_reference():
155
+ """
156
+ Tests ContainsExpression with ConstantValueReference for valid operations.
157
+ """
158
+ state = TestState()
159
+ lhs_ref = ConstantValueReference([1, 2, 3])
160
+ rhs_ref = ConstantValueReference(2)
161
+
162
+ expression: ContainsExpression = ContainsExpression(lhs=lhs_ref, rhs=rhs_ref)
163
+ result = expression.resolve(state)
164
+
165
+ assert result is True
166
+
167
+
168
+ def test_expression_metadata():
169
+ """
170
+ Tests that ContainsExpression has correct name and types properties.
171
+ """
172
+ expression = ContainsExpression(lhs=[1, 2, 3], rhs=2)
173
+
174
+ assert expression.name == "[1, 2, 3] contains 2"
175
+ assert expression.types == (bool,)