vellum-ai 1.2.2__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/__init__.py +40 -0
- vellum/client/core/client_wrapper.py +2 -2
- vellum/client/core/pydantic_utilities.py +3 -2
- vellum/client/reference.md +16 -0
- vellum/client/resources/workflow_executions/client.py +28 -4
- vellum/client/resources/workflow_executions/raw_client.py +32 -2
- vellum/client/types/__init__.py +40 -0
- vellum/client/types/audio_input_request.py +30 -0
- vellum/client/types/delimiter_chunker_config.py +20 -0
- vellum/client/types/delimiter_chunker_config_request.py +20 -0
- vellum/client/types/delimiter_chunking.py +21 -0
- vellum/client/types/delimiter_chunking_request.py +21 -0
- vellum/client/types/document_index_chunking.py +4 -1
- vellum/client/types/document_index_chunking_request.py +2 -1
- vellum/client/types/document_input_request.py +30 -0
- vellum/client/types/execution_audio_vellum_value.py +31 -0
- vellum/client/types/execution_document_vellum_value.py +31 -0
- vellum/client/types/execution_image_vellum_value.py +31 -0
- vellum/client/types/execution_vellum_value.py +8 -0
- vellum/client/types/execution_video_vellum_value.py +31 -0
- vellum/client/types/image_input_request.py +30 -0
- vellum/client/types/logical_operator.py +1 -0
- vellum/client/types/node_input_compiled_audio_value.py +23 -0
- vellum/client/types/node_input_compiled_document_value.py +23 -0
- vellum/client/types/node_input_compiled_image_value.py +23 -0
- vellum/client/types/node_input_compiled_video_value.py +23 -0
- vellum/client/types/node_input_variable_compiled_value.py +8 -0
- vellum/client/types/prompt_deployment_input_request.py +13 -1
- vellum/client/types/prompt_request_audio_input.py +26 -0
- vellum/client/types/prompt_request_document_input.py +26 -0
- vellum/client/types/prompt_request_image_input.py +26 -0
- vellum/client/types/prompt_request_input.py +13 -1
- vellum/client/types/prompt_request_video_input.py +26 -0
- vellum/client/types/video_input_request.py +30 -0
- vellum/types/audio_input_request.py +3 -0
- vellum/types/delimiter_chunker_config.py +3 -0
- vellum/types/delimiter_chunker_config_request.py +3 -0
- vellum/types/delimiter_chunking.py +3 -0
- vellum/types/delimiter_chunking_request.py +3 -0
- vellum/types/document_input_request.py +3 -0
- vellum/types/execution_audio_vellum_value.py +3 -0
- vellum/types/execution_document_vellum_value.py +3 -0
- vellum/types/execution_image_vellum_value.py +3 -0
- vellum/types/execution_video_vellum_value.py +3 -0
- vellum/types/image_input_request.py +3 -0
- vellum/types/node_input_compiled_audio_value.py +3 -0
- vellum/types/node_input_compiled_document_value.py +3 -0
- vellum/types/node_input_compiled_image_value.py +3 -0
- vellum/types/node_input_compiled_video_value.py +3 -0
- vellum/types/prompt_request_audio_input.py +3 -0
- vellum/types/prompt_request_document_input.py +3 -0
- vellum/types/prompt_request_image_input.py +3 -0
- vellum/types/prompt_request_video_input.py +3 -0
- vellum/types/video_input_request.py +3 -0
- vellum/workflows/context.py +27 -9
- vellum/workflows/events/context.py +53 -78
- vellum/workflows/events/node.py +5 -5
- vellum/workflows/events/relational_threads.py +41 -0
- vellum/workflows/events/tests/test_basic_workflow.py +50 -0
- vellum/workflows/events/workflow.py +12 -1
- vellum/workflows/expressions/contains.py +7 -0
- vellum/workflows/expressions/tests/test_contains.py +175 -0
- vellum/workflows/graph/graph.py +52 -8
- vellum/workflows/graph/tests/test_graph.py +17 -0
- vellum/workflows/integrations/mcp_service.py +35 -5
- vellum/workflows/integrations/tests/test_mcp_service.py +81 -0
- vellum/workflows/nodes/core/error_node/node.py +4 -0
- vellum/workflows/nodes/core/map_node/node.py +7 -0
- vellum/workflows/nodes/core/map_node/tests/test_node.py +19 -0
- vellum/workflows/nodes/displayable/final_output_node/node.py +4 -0
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +1 -1
- vellum/workflows/ports/node_ports.py +3 -0
- vellum/workflows/ports/port.py +7 -0
- vellum/workflows/state/context.py +35 -4
- vellum/workflows/utils/uuids.py +15 -0
- {vellum_ai-1.2.2.dist-info → vellum_ai-1.2.3.dist-info}/METADATA +1 -1
- {vellum_ai-1.2.2.dist-info → vellum_ai-1.2.3.dist-info}/RECORD +85 -39
- vellum_ee/workflows/display/nodes/vellum/error_node.py +1 -5
- vellum_ee/workflows/display/nodes/vellum/final_output_node.py +1 -5
- vellum_ee/workflows/display/utils/events.py +24 -0
- vellum_ee/workflows/display/utils/tests/test_events.py +69 -0
- vellum_ee/workflows/tests/test_server.py +95 -0
- {vellum_ai-1.2.2.dist-info → vellum_ai-1.2.3.dist-info}/LICENSE +0 -0
- {vellum_ai-1.2.2.dist-info → vellum_ai-1.2.3.dist-info}/WHEEL +0 -0
- {vellum_ai-1.2.2.dist-info → vellum_ai-1.2.3.dist-info}/entry_points.txt +0 -0
vellum/workflows/context.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1
1
|
from contextlib import contextmanager
|
2
2
|
from dataclasses import field
|
3
3
|
import threading
|
4
|
-
from uuid import UUID
|
4
|
+
from uuid import UUID, uuid4
|
5
5
|
from typing import Iterator, Optional, cast
|
6
6
|
|
7
7
|
from vellum.client.core import UniversalBaseModel
|
8
|
+
from vellum.workflows.events.context import MonitoringContextStore
|
8
9
|
from vellum.workflows.events.types import ParentContext
|
9
10
|
|
10
11
|
|
@@ -17,16 +18,37 @@ _CONTEXT_KEY = "_execution_context"
|
|
17
18
|
|
18
19
|
local = threading.local()
|
19
20
|
|
21
|
+
monitoring_context_store = MonitoringContextStore()
|
22
|
+
|
20
23
|
|
21
24
|
def get_execution_context() -> ExecutionContext:
|
22
|
-
"""
|
23
|
-
|
25
|
+
"""Get the current monitoring execution context, with intelligent fallback."""
|
26
|
+
context = getattr(local, _CONTEXT_KEY, ExecutionContext())
|
27
|
+
if context.parent_context:
|
28
|
+
return context
|
29
|
+
|
30
|
+
# If no thread-local context, try to restore from global store using current trace_id
|
31
|
+
context = monitoring_context_store.retrieve_context()
|
32
|
+
if context and context.parent_context:
|
33
|
+
set_execution_context(context)
|
34
|
+
return context
|
35
|
+
return ExecutionContext()
|
24
36
|
|
25
37
|
|
26
38
|
def set_execution_context(context: ExecutionContext) -> None:
|
27
|
-
"""Set the current execution context."""
|
39
|
+
"""Set the current monitoring execution context and persist it for cross-boundary access."""
|
28
40
|
setattr(local, _CONTEXT_KEY, context)
|
29
41
|
|
42
|
+
# Always store in global store for cross-thread access
|
43
|
+
monitoring_context_store.store_context(context)
|
44
|
+
|
45
|
+
|
46
|
+
def clear_execution_context() -> None:
|
47
|
+
"""Clear the current monitoring execution context."""
|
48
|
+
if hasattr(local, _CONTEXT_KEY):
|
49
|
+
delattr(local, _CONTEXT_KEY)
|
50
|
+
monitoring_context_store.clear_context()
|
51
|
+
|
30
52
|
|
31
53
|
def get_parent_context() -> ParentContext:
|
32
54
|
return cast(ParentContext, get_execution_context().parent_context)
|
@@ -38,11 +60,7 @@ def execution_context(
|
|
38
60
|
) -> Iterator[None]:
|
39
61
|
"""Context manager for handling execution context."""
|
40
62
|
prev_context = get_execution_context()
|
41
|
-
set_trace_id = (
|
42
|
-
prev_context.trace_id
|
43
|
-
if int(prev_context.trace_id)
|
44
|
-
else trace_id or UUID("00000000-0000-0000-0000-000000000000")
|
45
|
-
)
|
63
|
+
set_trace_id = prev_context.trace_id if int(prev_context.trace_id) else trace_id or uuid4()
|
46
64
|
set_parent_context = parent_context or prev_context.parent_context
|
47
65
|
set_context = ExecutionContext(parent_context=set_parent_context, trace_id=set_trace_id)
|
48
66
|
try:
|
@@ -2,110 +2,85 @@
|
|
2
2
|
|
3
3
|
import threading
|
4
4
|
from uuid import UUID
|
5
|
-
from typing import Dict, Optional
|
5
|
+
from typing import TYPE_CHECKING, Dict, Optional
|
6
6
|
|
7
|
-
|
7
|
+
if TYPE_CHECKING:
|
8
|
+
from vellum.workflows.context import ExecutionContext
|
8
9
|
|
9
10
|
DEFAULT_TRACE_ID = UUID("00000000-0000-0000-0000-000000000000")
|
10
11
|
|
11
|
-
# Thread-local storage for monitoring execution context
|
12
|
-
_monitoring_execution_context: threading.local = threading.local()
|
13
|
-
# Thread-local storage for current span_id
|
14
|
-
_current_span_id: threading.local = threading.local()
|
15
12
|
|
16
|
-
|
17
|
-
class _MonitoringContextStore:
|
13
|
+
class MonitoringContextStore:
|
18
14
|
"""
|
19
15
|
thread-safe storage for monitoring contexts.
|
20
16
|
handles context persistence and retrieval across threads.
|
21
|
-
relies on the execution context manager for manual retrieval
|
22
17
|
"""
|
23
18
|
|
24
19
|
def __init__(self):
|
25
20
|
self._lock = threading.Lock()
|
26
|
-
self._contexts: Dict[str, ExecutionContext] = {}
|
27
|
-
self._thread_contexts: Dict[int, ExecutionContext] = {}
|
28
|
-
self._current_trace_id: Optional[UUID] = None
|
29
|
-
|
30
|
-
def set_current_trace_id(self, trace_id: UUID) -> None:
|
31
|
-
"""Set the current active trace_id that should be used by all threads."""
|
32
|
-
if trace_id != DEFAULT_TRACE_ID:
|
33
|
-
with self._lock:
|
34
|
-
self._current_trace_id = trace_id
|
21
|
+
self._contexts: Dict[str, "ExecutionContext"] = {}
|
35
22
|
|
36
23
|
def get_current_trace_id(self) -> Optional[UUID]:
|
37
24
|
"""Get the current active trace_id that should be used by all threads."""
|
38
25
|
with self._lock:
|
39
|
-
|
40
|
-
|
41
|
-
def set_current_span_id(self, span_id: UUID) -> None:
|
42
|
-
"""Set the current active span_id for this thread."""
|
43
|
-
_current_span_id.span_id = span_id
|
44
|
-
|
45
|
-
def get_current_span_id(self) -> Optional[UUID]:
|
46
|
-
"""Get the current active span_id for this thread."""
|
47
|
-
return getattr(_current_span_id, "span_id", None)
|
48
|
-
|
49
|
-
def store_context(self, context: Optional[ExecutionContext]) -> None:
|
50
|
-
"""Store monitoring parent context using multiple keys for reliable retrieval."""
|
51
|
-
if not context or context.parent_context is None:
|
52
|
-
return
|
26
|
+
current_context = self.retrieve_context()
|
27
|
+
return current_context.trace_id if current_context else None
|
53
28
|
|
29
|
+
def store_context(self, context: "ExecutionContext") -> None:
|
30
|
+
"""Store monitoring parent context using trace:span:thread keys."""
|
54
31
|
thread_id = threading.get_ident()
|
55
|
-
|
56
|
-
|
57
|
-
self.set_current_trace_id(context.trace_id)
|
32
|
+
current_thread = threading.current_thread()
|
33
|
+
trace_id = context.trace_id
|
58
34
|
|
59
35
|
with self._lock:
|
60
|
-
#
|
61
|
-
|
62
|
-
f"trace:{str(trace_id)}:span:{str(context.parent_context.span_id)}:thread:{thread_id}"
|
63
|
-
)
|
64
|
-
self._contexts[trace_span_thread_key] = context
|
65
|
-
|
66
|
-
def retrieve_context(self, trace_id: UUID, span_id: Optional[UUID] = None) -> Optional[ExecutionContext]:
|
67
|
-
"""Retrieve monitoring parent context with multiple fallback strategies."""
|
68
|
-
thread_id = threading.get_ident()
|
69
|
-
with self._lock:
|
70
|
-
if not span_id:
|
71
|
-
span_id = getattr(_current_span_id, "span_id", None)
|
72
|
-
if not span_id:
|
73
|
-
return None
|
74
|
-
|
75
|
-
span_key = f"trace:{str(trace_id)}:span:{str(span_id)}:thread:{thread_id}"
|
76
|
-
if span_key in self._contexts:
|
77
|
-
result = self._contexts[span_key]
|
78
|
-
return result
|
36
|
+
# Get span_id from RelationalThread first
|
37
|
+
span_id = None
|
79
38
|
|
80
|
-
|
39
|
+
if context.parent_context and hasattr(context.parent_context, "span_id"):
|
40
|
+
span_id = context.parent_context.span_id
|
81
41
|
|
42
|
+
if not span_id and hasattr(current_thread, "get_parent_span_id"):
|
43
|
+
span_id = current_thread.get_parent_span_id()
|
82
44
|
|
83
|
-
#
|
84
|
-
|
45
|
+
# Always use trace:span:thread format - require span_id
|
46
|
+
if span_id:
|
47
|
+
context_key = f"trace:{str(trace_id)}:span:{str(span_id)}:thread:{thread_id}"
|
48
|
+
self._contexts[context_key] = context
|
85
49
|
|
50
|
+
def retrieve_context(self) -> Optional["ExecutionContext"]:
|
51
|
+
"""Retrieve monitoring parent context using trace:span:thread keys."""
|
52
|
+
current_thread = threading.current_thread()
|
53
|
+
current_thread_id = current_thread.ident
|
86
54
|
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
55
|
+
# Get trace_id and span_id directly from the thread if it's a RelationalThread
|
56
|
+
trace_id = None
|
57
|
+
span_id = None
|
58
|
+
if hasattr(current_thread, "get_trace_id"):
|
59
|
+
trace_id = current_thread.get_trace_id()
|
60
|
+
if hasattr(current_thread, "get_parent_span_id"):
|
61
|
+
span_id = current_thread.get_parent_span_id()
|
93
62
|
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
if trace_id:
|
98
|
-
if trace_id != DEFAULT_TRACE_ID:
|
99
|
-
context = _monitoring_context_store.retrieve_context(trace_id, span_id)
|
100
|
-
if context:
|
101
|
-
_monitoring_execution_context.context = context
|
102
|
-
return context
|
103
|
-
return ExecutionContext()
|
63
|
+
# Require both trace_id and span_id - no fallback searching
|
64
|
+
if not trace_id or not span_id:
|
65
|
+
return None
|
104
66
|
|
67
|
+
with self._lock:
|
68
|
+
# Try current thread
|
69
|
+
current_key = f"trace:{str(trace_id)}:span:{str(span_id)}:thread:{current_thread_id}"
|
70
|
+
if current_key in self._contexts:
|
71
|
+
return self._contexts[current_key]
|
72
|
+
|
73
|
+
# Try parent thread with same trace and span
|
74
|
+
if hasattr(current_thread, "get_parent_thread"):
|
75
|
+
parent_thread_id = current_thread.get_parent_thread()
|
76
|
+
if parent_thread_id:
|
77
|
+
parent_key = f"trace:{str(trace_id)}:span:{str(span_id)}:thread:{parent_thread_id}"
|
78
|
+
if parent_key in self._contexts:
|
79
|
+
return self._contexts[parent_key]
|
105
80
|
|
106
|
-
|
107
|
-
"""Set the current monitoring execution context and persist it for cross-boundary access."""
|
108
|
-
_monitoring_execution_context.context = context
|
81
|
+
return None
|
109
82
|
|
110
|
-
|
111
|
-
|
83
|
+
def clear_context(self):
|
84
|
+
"""Clear all stored contexts."""
|
85
|
+
with self._lock:
|
86
|
+
self._contexts.clear()
|
vellum/workflows/events/node.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Literal, Optional, Set, Type, Union
|
2
2
|
|
3
|
-
from pydantic import field_serializer, model_serializer
|
3
|
+
from pydantic import SerializationInfo, field_serializer, model_serializer
|
4
4
|
|
5
5
|
from vellum.client.core.pydantic_utilities import UniversalBaseModel
|
6
6
|
from vellum.workflows.errors import WorkflowError
|
@@ -82,8 +82,8 @@ class NodeExecutionStreamingEvent(_BaseNodeEvent):
|
|
82
82
|
return self.body.invoked_ports
|
83
83
|
|
84
84
|
@model_serializer(mode="plain", when_used="json")
|
85
|
-
def serialize_model(self) -> Any:
|
86
|
-
serialized = super().serialize_model()
|
85
|
+
def serialize_model(self, info: SerializationInfo) -> Any:
|
86
|
+
serialized = super().serialize_model(info)
|
87
87
|
if (
|
88
88
|
"body" in serialized
|
89
89
|
and isinstance(serialized["body"], dict)
|
@@ -127,8 +127,8 @@ class NodeExecutionFulfilledEvent(_BaseNodeEvent, Generic[OutputsType]):
|
|
127
127
|
return self.body.mocked
|
128
128
|
|
129
129
|
@model_serializer(mode="plain", when_used="json")
|
130
|
-
def serialize_model(self) -> Any:
|
131
|
-
serialized = super().serialize_model()
|
130
|
+
def serialize_model(self, info: SerializationInfo) -> Any:
|
131
|
+
serialized = super().serialize_model(info)
|
132
132
|
if (
|
133
133
|
"body" in serialized
|
134
134
|
and isinstance(serialized["body"], dict)
|
@@ -0,0 +1,41 @@
|
|
1
|
+
import threading
|
2
|
+
from uuid import UUID
|
3
|
+
from typing import TYPE_CHECKING, Optional
|
4
|
+
|
5
|
+
if TYPE_CHECKING:
|
6
|
+
from vellum.workflows.context import ExecutionContext
|
7
|
+
|
8
|
+
|
9
|
+
class RelationalThread(threading.Thread):
|
10
|
+
_parent_thread: Optional[int] = None
|
11
|
+
_trace_id: Optional[UUID] = None
|
12
|
+
_parent_span_id: Optional[UUID] = None
|
13
|
+
|
14
|
+
def __init__(self, *args, execution_context: Optional["ExecutionContext"] = None, **kwargs):
|
15
|
+
self._collect_parent_context(execution_context)
|
16
|
+
threading.Thread.__init__(self, *args, **kwargs)
|
17
|
+
|
18
|
+
def _collect_parent_context(self, execution_context: Optional["ExecutionContext"] = None) -> None:
|
19
|
+
"""Collect parent thread ID, trace ID, and parent span ID from passed execution context."""
|
20
|
+
self._parent_thread = threading.get_ident()
|
21
|
+
|
22
|
+
# Only use explicitly passed execution context
|
23
|
+
if execution_context:
|
24
|
+
self._trace_id = execution_context.trace_id
|
25
|
+
self._parent_span_id = (
|
26
|
+
execution_context.parent_context.span_id
|
27
|
+
if execution_context.parent_context and hasattr(execution_context.parent_context, "span_id")
|
28
|
+
else None
|
29
|
+
)
|
30
|
+
else:
|
31
|
+
self._trace_id = None
|
32
|
+
self._parent_span_id = None
|
33
|
+
|
34
|
+
def get_parent_thread(self) -> Optional[int]:
|
35
|
+
return self._parent_thread
|
36
|
+
|
37
|
+
def get_trace_id(self) -> Optional[UUID]:
|
38
|
+
return self._trace_id
|
39
|
+
|
40
|
+
def get_parent_span_id(self) -> Optional[UUID]:
|
41
|
+
return self._parent_span_id
|
@@ -0,0 +1,50 @@
|
|
1
|
+
import logging
|
2
|
+
from uuid import uuid4
|
3
|
+
|
4
|
+
from vellum.workflows import BaseWorkflow
|
5
|
+
from vellum.workflows.context import execution_context
|
6
|
+
from vellum.workflows.nodes.bases import BaseNode
|
7
|
+
from vellum.workflows.workflows.event_filters import all_workflow_event_filter
|
8
|
+
|
9
|
+
|
10
|
+
class StartNode(BaseNode):
|
11
|
+
pass
|
12
|
+
|
13
|
+
|
14
|
+
class TrivialWorkflow(BaseWorkflow):
|
15
|
+
graph = StartNode
|
16
|
+
|
17
|
+
|
18
|
+
logger = logging.getLogger(__name__)
|
19
|
+
|
20
|
+
|
21
|
+
def test_basic_workflow_monitoring_context_flow():
|
22
|
+
"""Test that monitoring creates the correct workflow→node context hierarchy using streamed events.
|
23
|
+
What's missing:
|
24
|
+
- span_id from previous event mapping to parent context
|
25
|
+
"""
|
26
|
+
|
27
|
+
workflow = TrivialWorkflow()
|
28
|
+
|
29
|
+
with execution_context(trace_id=uuid4()):
|
30
|
+
events = list(workflow.stream(event_filter=all_workflow_event_filter))
|
31
|
+
|
32
|
+
# Verify workflow succeeded
|
33
|
+
assert len(events) >= 2
|
34
|
+
assert events[0].name == "workflow.execution.initiated"
|
35
|
+
assert events[-1].name == "workflow.execution.fulfilled"
|
36
|
+
|
37
|
+
# Collect all events with parent context
|
38
|
+
events_with_parent = [event for event in events if event.parent is not None]
|
39
|
+
|
40
|
+
assert len(events_with_parent) > 0, "Expected at least some events with parent context"
|
41
|
+
|
42
|
+
# Filter for node events
|
43
|
+
node_events = [event for event in events if event.name.startswith("node.")]
|
44
|
+
assert len(node_events) > 0, "Expected at least some node events"
|
45
|
+
|
46
|
+
# Verify each node event has the workflow as its parent context
|
47
|
+
for event in node_events:
|
48
|
+
assert event.parent is not None, "Node event should have parent context"
|
49
|
+
assert event.parent.type == "WORKFLOW", "Node event parent should be workflow"
|
50
|
+
assert event.parent.workflow_definition.name == "TrivialWorkflow", "Parent workflow name mismatch"
|
@@ -2,7 +2,7 @@ from uuid import UUID
|
|
2
2
|
from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Literal, Optional, Type, Union
|
3
3
|
from typing_extensions import TypeGuard
|
4
4
|
|
5
|
-
from pydantic import field_serializer
|
5
|
+
from pydantic import SerializationInfo, field_serializer
|
6
6
|
|
7
7
|
from vellum.client.core.pydantic_utilities import UniversalBaseModel
|
8
8
|
from vellum.workflows.errors import WorkflowError
|
@@ -101,6 +101,17 @@ class WorkflowExecutionInitiatedEvent(_BaseWorkflowEvent, Generic[InputsType, St
|
|
101
101
|
def initial_state(self) -> Optional[StateType]:
|
102
102
|
return self.body.initial_state
|
103
103
|
|
104
|
+
@field_serializer("body")
|
105
|
+
def serialize_body(
|
106
|
+
self, body: WorkflowExecutionInitiatedBody[InputsType, StateType], info: SerializationInfo
|
107
|
+
) -> WorkflowExecutionInitiatedBody[InputsType, StateType]:
|
108
|
+
context = info.context if info and hasattr(info, "context") else {}
|
109
|
+
if context and "event_enricher" in context and callable(context["event_enricher"]):
|
110
|
+
event = context["event_enricher"](self)
|
111
|
+
return event.body
|
112
|
+
else:
|
113
|
+
return body
|
114
|
+
|
104
115
|
|
105
116
|
class WorkflowExecutionStreamingBody(_BaseWorkflowExecutionBody):
|
106
117
|
output: BaseOutput
|
@@ -35,4 +35,11 @@ class ContainsExpression(BaseDescriptor[bool], Generic[LHS, RHS]):
|
|
35
35
|
)
|
36
36
|
|
37
37
|
rhs = resolve_value(self._rhs, state)
|
38
|
+
|
39
|
+
if isinstance(rhs, dict):
|
40
|
+
raise InvalidExpressionException(
|
41
|
+
"Cannot use dict as right-hand side of contains operation. "
|
42
|
+
"Use dict keys/values or convert to strings for comparison."
|
43
|
+
)
|
44
|
+
|
38
45
|
return rhs in lhs
|
@@ -0,0 +1,175 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
from vellum.workflows.constants import undefined
|
4
|
+
from vellum.workflows.descriptors.exceptions import InvalidExpressionException
|
5
|
+
from vellum.workflows.expressions.contains import ContainsExpression
|
6
|
+
from vellum.workflows.references.constant import ConstantValueReference
|
7
|
+
from vellum.workflows.state.base import BaseState
|
8
|
+
|
9
|
+
|
10
|
+
class TestState(BaseState):
|
11
|
+
dict_value: dict = {"key": "value"}
|
12
|
+
list_value: list = [1, 2, 3]
|
13
|
+
string_value: str = "hello world"
|
14
|
+
|
15
|
+
|
16
|
+
def test_dict_contains_dict_raises_error():
|
17
|
+
"""
|
18
|
+
Tests that ContainsExpression raises clear error for dict-contains-dict scenarios.
|
19
|
+
"""
|
20
|
+
state = TestState()
|
21
|
+
lhs_dict = {"foo": "bar"}
|
22
|
+
rhs_dict = {"foo": "bar"}
|
23
|
+
|
24
|
+
expression = ContainsExpression(lhs=lhs_dict, rhs=rhs_dict)
|
25
|
+
|
26
|
+
with pytest.raises(InvalidExpressionException, match="Cannot use dict as right-hand side"):
|
27
|
+
expression.resolve(state)
|
28
|
+
|
29
|
+
|
30
|
+
def test_dict_contains_different_dict_raises_error():
|
31
|
+
"""
|
32
|
+
Tests that ContainsExpression raises clear error for different dict-contains-dict scenarios.
|
33
|
+
"""
|
34
|
+
state = TestState()
|
35
|
+
lhs_dict = {"foo": "bar"}
|
36
|
+
rhs_dict = {"hello": "world"}
|
37
|
+
|
38
|
+
expression = ContainsExpression(lhs=lhs_dict, rhs=rhs_dict)
|
39
|
+
|
40
|
+
with pytest.raises(InvalidExpressionException, match="Cannot use dict as right-hand side"):
|
41
|
+
expression.resolve(state)
|
42
|
+
|
43
|
+
|
44
|
+
def test_string_contains_dict_raises_error():
|
45
|
+
"""
|
46
|
+
Tests that ContainsExpression raises clear error for string-contains-dict scenarios.
|
47
|
+
"""
|
48
|
+
state = TestState()
|
49
|
+
lhs_string = 'Response: {"status": "success"} was returned'
|
50
|
+
rhs_dict = {"status": "success"}
|
51
|
+
|
52
|
+
expression = ContainsExpression(lhs=lhs_string, rhs=rhs_dict)
|
53
|
+
|
54
|
+
with pytest.raises(InvalidExpressionException, match="Cannot use dict as right-hand side"):
|
55
|
+
expression.resolve(state)
|
56
|
+
|
57
|
+
|
58
|
+
def test_nested_dict_contains_dict_raises_error():
|
59
|
+
"""
|
60
|
+
Tests that ContainsExpression raises clear error for nested dict scenarios.
|
61
|
+
"""
|
62
|
+
state = TestState()
|
63
|
+
lhs_dict = {"user": {"name": "john", "age": 30}}
|
64
|
+
rhs_dict = {"age": 30, "name": "john"}
|
65
|
+
|
66
|
+
expression = ContainsExpression(lhs=lhs_dict, rhs=rhs_dict)
|
67
|
+
|
68
|
+
with pytest.raises(InvalidExpressionException, match="Cannot use dict as right-hand side"):
|
69
|
+
expression.resolve(state)
|
70
|
+
|
71
|
+
|
72
|
+
def test_list_contains_string():
|
73
|
+
"""
|
74
|
+
Tests that ContainsExpression preserves original list functionality.
|
75
|
+
"""
|
76
|
+
state = TestState()
|
77
|
+
|
78
|
+
expression = TestState.list_value.contains(2)
|
79
|
+
result = expression.resolve(state)
|
80
|
+
|
81
|
+
assert result is True
|
82
|
+
|
83
|
+
|
84
|
+
def test_string_contains_substring():
|
85
|
+
"""
|
86
|
+
Tests that ContainsExpression preserves original string functionality.
|
87
|
+
"""
|
88
|
+
state = TestState()
|
89
|
+
|
90
|
+
expression = TestState.string_value.contains("world")
|
91
|
+
result = expression.resolve(state)
|
92
|
+
|
93
|
+
assert result is True
|
94
|
+
|
95
|
+
|
96
|
+
def test_set_contains_item():
|
97
|
+
"""
|
98
|
+
Tests that ContainsExpression works with sets.
|
99
|
+
"""
|
100
|
+
state = TestState()
|
101
|
+
lhs_set = {1, 2, 3}
|
102
|
+
rhs_item = 2
|
103
|
+
|
104
|
+
expression = ContainsExpression(lhs=lhs_set, rhs=rhs_item)
|
105
|
+
result = expression.resolve(state)
|
106
|
+
|
107
|
+
assert result is True
|
108
|
+
|
109
|
+
|
110
|
+
def test_tuple_contains_item():
|
111
|
+
"""
|
112
|
+
Tests that ContainsExpression works with tuples.
|
113
|
+
"""
|
114
|
+
state = TestState()
|
115
|
+
lhs_tuple = (1, 2, 3)
|
116
|
+
rhs_item = 2
|
117
|
+
|
118
|
+
expression = ContainsExpression(lhs=lhs_tuple, rhs=rhs_item)
|
119
|
+
result = expression.resolve(state)
|
120
|
+
|
121
|
+
assert result is True
|
122
|
+
|
123
|
+
|
124
|
+
def test_invalid_lhs_type():
|
125
|
+
"""
|
126
|
+
Tests that ContainsExpression raises exception for invalid LHS types.
|
127
|
+
"""
|
128
|
+
|
129
|
+
class NoContainsSupport:
|
130
|
+
pass
|
131
|
+
|
132
|
+
state = TestState()
|
133
|
+
no_contains_obj = NoContainsSupport()
|
134
|
+
expression = ContainsExpression(lhs=no_contains_obj, rhs="test")
|
135
|
+
|
136
|
+
with pytest.raises(
|
137
|
+
InvalidExpressionException, match="Expected a LHS that supported `contains`, got `NoContainsSupport`"
|
138
|
+
):
|
139
|
+
expression.resolve(state)
|
140
|
+
|
141
|
+
|
142
|
+
def test_undefined_lhs_returns_false():
|
143
|
+
"""
|
144
|
+
Tests that ContainsExpression returns False for undefined LHS.
|
145
|
+
"""
|
146
|
+
state = TestState()
|
147
|
+
expression = ContainsExpression(lhs=undefined, rhs="test")
|
148
|
+
|
149
|
+
result = expression.resolve(state)
|
150
|
+
|
151
|
+
assert result is False
|
152
|
+
|
153
|
+
|
154
|
+
def test_contains_with_constant_value_reference():
|
155
|
+
"""
|
156
|
+
Tests ContainsExpression with ConstantValueReference for valid operations.
|
157
|
+
"""
|
158
|
+
state = TestState()
|
159
|
+
lhs_ref = ConstantValueReference([1, 2, 3])
|
160
|
+
rhs_ref = ConstantValueReference(2)
|
161
|
+
|
162
|
+
expression: ContainsExpression = ContainsExpression(lhs=lhs_ref, rhs=rhs_ref)
|
163
|
+
result = expression.resolve(state)
|
164
|
+
|
165
|
+
assert result is True
|
166
|
+
|
167
|
+
|
168
|
+
def test_expression_metadata():
|
169
|
+
"""
|
170
|
+
Tests that ContainsExpression has correct name and types properties.
|
171
|
+
"""
|
172
|
+
expression = ContainsExpression(lhs=[1, 2, 3], rhs=2)
|
173
|
+
|
174
|
+
assert expression.name == "[1, 2, 3] contains 2"
|
175
|
+
assert expression.types == (bool,)
|