vellum-ai 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +2 -2
- vellum/client/types/integration_name.py +1 -0
- vellum/workflows/expressions/concat.py +6 -3
- vellum/workflows/expressions/tests/test_concat.py +63 -8
- vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +20 -5
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +11 -7
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/tests/test_inline_prompt_node.py +42 -0
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +84 -56
- vellum/workflows/nodes/experimental/__init__.py +1 -3
- vellum/workflows/runner/runner.py +144 -0
- vellum/workflows/state/context.py +59 -7
- vellum/workflows/workflows/base.py +17 -0
- vellum/workflows/workflows/event_filters.py +13 -0
- vellum/workflows/workflows/tests/test_event_filters.py +126 -0
- {vellum_ai-1.8.2.dist-info → vellum_ai-1.8.3.dist-info}/METADATA +1 -1
- {vellum_ai-1.8.2.dist-info → vellum_ai-1.8.3.dist-info}/RECORD +19 -20
- vellum/workflows/nodes/experimental/openai_chat_completion_node/__init__.py +0 -5
- vellum/workflows/nodes/experimental/openai_chat_completion_node/node.py +0 -266
- {vellum_ai-1.8.2.dist-info → vellum_ai-1.8.3.dist-info}/LICENSE +0 -0
- {vellum_ai-1.8.2.dist-info → vellum_ai-1.8.3.dist-info}/WHEEL +0 -0
- {vellum_ai-1.8.2.dist-info → vellum_ai-1.8.3.dist-info}/entry_points.txt +0 -0
|
@@ -27,10 +27,10 @@ class BaseClientWrapper:
|
|
|
27
27
|
|
|
28
28
|
def get_headers(self) -> typing.Dict[str, str]:
|
|
29
29
|
headers: typing.Dict[str, str] = {
|
|
30
|
-
"User-Agent": "vellum-ai/1.8.
|
|
30
|
+
"User-Agent": "vellum-ai/1.8.3",
|
|
31
31
|
"X-Fern-Language": "Python",
|
|
32
32
|
"X-Fern-SDK-Name": "vellum-ai",
|
|
33
|
-
"X-Fern-SDK-Version": "1.8.
|
|
33
|
+
"X-Fern-SDK-Version": "1.8.3",
|
|
34
34
|
**(self.get_custom_headers() or {}),
|
|
35
35
|
}
|
|
36
36
|
if self._api_version is not None:
|
|
@@ -26,7 +26,10 @@ class ConcatExpression(BaseDescriptor[list], Generic[LHS, RHS]):
|
|
|
26
26
|
|
|
27
27
|
if not isinstance(lval, list):
|
|
28
28
|
raise InvalidExpressionException(f"Expected LHS to be a list, got {type(lval)}")
|
|
29
|
-
if not isinstance(rval, list):
|
|
30
|
-
raise InvalidExpressionException(f"Expected RHS to be a list, got {type(rval)}")
|
|
31
29
|
|
|
32
|
-
|
|
30
|
+
# If RHS is a list, concatenate normally
|
|
31
|
+
if isinstance(rval, list):
|
|
32
|
+
return lval + rval
|
|
33
|
+
# If RHS is not a list, append it as a single item
|
|
34
|
+
else:
|
|
35
|
+
return lval + [rval]
|
|
@@ -38,16 +38,71 @@ def test_concat_expression_lhs_fail():
|
|
|
38
38
|
assert "Expected LHS to be a list, got <class 'int'>" in str(exc_info.value)
|
|
39
39
|
|
|
40
40
|
|
|
41
|
-
def
|
|
42
|
-
# GIVEN a list
|
|
41
|
+
def test_concat_expression_with_single_item():
|
|
42
|
+
# GIVEN a list and a single item
|
|
43
43
|
state = TestState()
|
|
44
44
|
lhs_ref = ConstantValueReference([1, 2, 3])
|
|
45
|
-
rhs_ref = ConstantValueReference(
|
|
45
|
+
rhs_ref = ConstantValueReference(4)
|
|
46
46
|
concat_expr = lhs_ref.concat(rhs_ref)
|
|
47
47
|
|
|
48
|
-
# WHEN we
|
|
49
|
-
|
|
50
|
-
concat_expr.resolve(state)
|
|
48
|
+
# WHEN we resolve the expression
|
|
49
|
+
result = concat_expr.resolve(state)
|
|
51
50
|
|
|
52
|
-
# THEN
|
|
53
|
-
assert
|
|
51
|
+
# THEN the item should be appended to the list
|
|
52
|
+
assert result == [1, 2, 3, 4]
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def test_concat_expression_with_string():
|
|
56
|
+
# GIVEN a list and a string item
|
|
57
|
+
state = TestState()
|
|
58
|
+
lhs_ref = ConstantValueReference(["hello", "world"])
|
|
59
|
+
rhs_ref = ConstantValueReference("!")
|
|
60
|
+
concat_expr = lhs_ref.concat(rhs_ref)
|
|
61
|
+
|
|
62
|
+
# WHEN we resolve the expression
|
|
63
|
+
result = concat_expr.resolve(state)
|
|
64
|
+
|
|
65
|
+
# THEN the string should be appended to the list
|
|
66
|
+
assert result == ["hello", "world", "!"]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def test_concat_expression_with_none():
|
|
70
|
+
# GIVEN a list and None item
|
|
71
|
+
state = TestState()
|
|
72
|
+
lhs_ref = ConstantValueReference([1, 2])
|
|
73
|
+
rhs_ref = ConstantValueReference(None)
|
|
74
|
+
concat_expr = lhs_ref.concat(rhs_ref)
|
|
75
|
+
|
|
76
|
+
# WHEN we resolve the expression
|
|
77
|
+
result = concat_expr.resolve(state)
|
|
78
|
+
|
|
79
|
+
# THEN None should be appended to the list
|
|
80
|
+
assert result == [1, 2, None]
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def test_concat_expression_with_list_item():
|
|
84
|
+
# GIVEN a list and another list
|
|
85
|
+
state = TestState()
|
|
86
|
+
lhs_ref = ConstantValueReference([1, 2])
|
|
87
|
+
rhs_ref = ConstantValueReference([3, 4])
|
|
88
|
+
concat_expr = lhs_ref.concat(rhs_ref)
|
|
89
|
+
|
|
90
|
+
# WHEN we resolve the expression
|
|
91
|
+
result = concat_expr.resolve(state)
|
|
92
|
+
|
|
93
|
+
# THEN the lists should be concatenated normally
|
|
94
|
+
assert result == [1, 2, 3, 4]
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def test_concat_expression_empty_list():
|
|
98
|
+
# GIVEN an empty list and an item
|
|
99
|
+
state = TestState()
|
|
100
|
+
lhs_ref: ConstantValueReference[list] = ConstantValueReference([])
|
|
101
|
+
rhs_ref = ConstantValueReference("first")
|
|
102
|
+
concat_expr = lhs_ref.concat(rhs_ref)
|
|
103
|
+
|
|
104
|
+
# WHEN we resolve the expression
|
|
105
|
+
result = concat_expr.resolve(state)
|
|
106
|
+
|
|
107
|
+
# THEN the item should be appended to the empty list
|
|
108
|
+
assert result == ["first"]
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from abc import abstractmethod
|
|
2
|
-
from
|
|
2
|
+
from itertools import chain
|
|
3
|
+
from typing import ClassVar, Generator, Generic, Iterator, List, Optional, Union, cast
|
|
3
4
|
|
|
4
5
|
from vellum import AdHocExecutePromptEvent, ExecutePromptEvent, PromptOutput
|
|
5
6
|
from vellum.client.core import RequestOptions
|
|
@@ -60,15 +61,26 @@ class BasePromptNode(BaseNode[StateType], Generic[StateType]):
|
|
|
60
61
|
except ApiError as e:
|
|
61
62
|
self._handle_api_error(e)
|
|
62
63
|
|
|
63
|
-
# We don't use the INITIATED event anyway, so we can just skip it
|
|
64
|
-
# and use the exception handling to catch other api level errors
|
|
65
64
|
try:
|
|
66
|
-
next(prompt_event_stream)
|
|
65
|
+
first_event = next(prompt_event_stream)
|
|
67
66
|
except ApiError as e:
|
|
68
67
|
self._handle_api_error(e)
|
|
68
|
+
else:
|
|
69
|
+
if first_event.state == "REJECTED":
|
|
70
|
+
workflow_error = vellum_error_to_workflow_error(first_event.error)
|
|
71
|
+
raise NodeException.of(workflow_error)
|
|
72
|
+
if first_event.state != "INITIATED":
|
|
73
|
+
prompt_event_stream = cast(
|
|
74
|
+
Union[Iterator[AdHocExecutePromptEvent], Iterator[ExecutePromptEvent]],
|
|
75
|
+
chain([first_event], prompt_event_stream),
|
|
76
|
+
)
|
|
69
77
|
|
|
70
78
|
outputs: Optional[List[PromptOutput]] = None
|
|
79
|
+
exception: Optional[NodeException] = None
|
|
71
80
|
for event in prompt_event_stream:
|
|
81
|
+
if exception:
|
|
82
|
+
continue
|
|
83
|
+
|
|
72
84
|
if event.state == "INITIATED":
|
|
73
85
|
continue
|
|
74
86
|
elif event.state == "STREAMING":
|
|
@@ -78,7 +90,10 @@ class BasePromptNode(BaseNode[StateType], Generic[StateType]):
|
|
|
78
90
|
yield BaseOutput(name="results", value=event.outputs)
|
|
79
91
|
elif event.state == "REJECTED":
|
|
80
92
|
workflow_error = vellum_error_to_workflow_error(event.error)
|
|
81
|
-
|
|
93
|
+
exception = NodeException.of(workflow_error)
|
|
94
|
+
|
|
95
|
+
if exception:
|
|
96
|
+
raise exception
|
|
82
97
|
|
|
83
98
|
return outputs
|
|
84
99
|
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from itertools import chain
|
|
1
2
|
import json
|
|
2
3
|
from uuid import uuid4
|
|
3
4
|
from typing import Callable, ClassVar, Generator, Generic, Iterator, List, Optional, Set, Tuple, Union
|
|
@@ -225,13 +226,16 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
|
|
|
225
226
|
except ApiError as e:
|
|
226
227
|
self._handle_api_error(e)
|
|
227
228
|
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
229
|
+
try:
|
|
230
|
+
first_event = next(prompt_event_stream)
|
|
231
|
+
except ApiError as e:
|
|
232
|
+
self._handle_api_error(e)
|
|
233
|
+
else:
|
|
234
|
+
if first_event.state == "REJECTED":
|
|
235
|
+
workflow_error = vellum_error_to_workflow_error(first_event.error)
|
|
236
|
+
raise NodeException.of(workflow_error)
|
|
237
|
+
if first_event.state != "INITIATED":
|
|
238
|
+
prompt_event_stream = chain([first_event], prompt_event_stream)
|
|
235
239
|
|
|
236
240
|
outputs: Optional[List[PromptOutput]] = None
|
|
237
241
|
for event in prompt_event_stream:
|
vellum/workflows/nodes/displayable/bases/inline_prompt_node/tests/test_inline_prompt_node.py
CHANGED
|
@@ -25,6 +25,7 @@ from vellum import (
|
|
|
25
25
|
PromptRequestStringInput,
|
|
26
26
|
PromptRequestVideoInput,
|
|
27
27
|
PromptSettings,
|
|
28
|
+
RejectedExecutePromptEvent,
|
|
28
29
|
RichTextPromptBlock,
|
|
29
30
|
StringVellumValue,
|
|
30
31
|
VariablePromptBlock,
|
|
@@ -32,6 +33,7 @@ from vellum import (
|
|
|
32
33
|
VellumAudioRequest,
|
|
33
34
|
VellumDocument,
|
|
34
35
|
VellumDocumentRequest,
|
|
36
|
+
VellumError,
|
|
35
37
|
VellumImage,
|
|
36
38
|
VellumImageRequest,
|
|
37
39
|
VellumVideo,
|
|
@@ -896,3 +898,43 @@ def test_inline_prompt_node__json_output_with_markdown_code_blocks(vellum_adhoc_
|
|
|
896
898
|
json_output = outputs[2]
|
|
897
899
|
assert json_output.name == "json"
|
|
898
900
|
assert json_output.value == expected_json
|
|
901
|
+
|
|
902
|
+
|
|
903
|
+
def test_inline_prompt_node__provider_error_from_api(vellum_adhoc_prompt_client):
|
|
904
|
+
"""
|
|
905
|
+
Tests that InlinePromptNode raises NodeException with PROVIDER_ERROR code when first event is REJECTED.
|
|
906
|
+
"""
|
|
907
|
+
|
|
908
|
+
# GIVEN an InlinePromptNode with basic configuration
|
|
909
|
+
class TestNode(InlinePromptNode):
|
|
910
|
+
ml_model = "test-model"
|
|
911
|
+
blocks = []
|
|
912
|
+
prompt_inputs = {}
|
|
913
|
+
|
|
914
|
+
# AND the API returns a REJECTED event as the first event with a provider error
|
|
915
|
+
provider_error = VellumError(
|
|
916
|
+
code="PROVIDER_ERROR",
|
|
917
|
+
message="Provider rate limit exceeded",
|
|
918
|
+
)
|
|
919
|
+
|
|
920
|
+
def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
|
|
921
|
+
execution_id = str(uuid4())
|
|
922
|
+
events: List[ExecutePromptEvent] = [
|
|
923
|
+
RejectedExecutePromptEvent(
|
|
924
|
+
execution_id=execution_id,
|
|
925
|
+
error=provider_error,
|
|
926
|
+
),
|
|
927
|
+
]
|
|
928
|
+
yield from events
|
|
929
|
+
|
|
930
|
+
vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.side_effect = generate_prompt_events
|
|
931
|
+
|
|
932
|
+
# WHEN the node is run
|
|
933
|
+
node = TestNode()
|
|
934
|
+
|
|
935
|
+
# THEN it should raise a NodeException with PROVIDER_ERROR error code
|
|
936
|
+
with pytest.raises(NodeException) as excinfo:
|
|
937
|
+
list(node.run())
|
|
938
|
+
|
|
939
|
+
# AND the exception should have the correct error code
|
|
940
|
+
assert excinfo.value.code == WorkflowErrorCode.PROVIDER_ERROR
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import json
|
|
2
|
-
from uuid import UUID
|
|
3
|
-
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Iterator, List, Optional, Set, Union, cast
|
|
2
|
+
from uuid import UUID, uuid4
|
|
3
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Iterator, List, Optional, Set, Type, Union, cast
|
|
4
4
|
|
|
5
5
|
from vellum import (
|
|
6
6
|
ChatMessage,
|
|
@@ -19,12 +19,13 @@ from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT, undefined
|
|
|
19
19
|
from vellum.workflows.context import execution_context, get_execution_context, get_parent_context
|
|
20
20
|
from vellum.workflows.errors import WorkflowErrorCode
|
|
21
21
|
from vellum.workflows.errors.types import workflow_event_error_to_workflow_error
|
|
22
|
-
from vellum.workflows.events.types import default_serializer
|
|
22
|
+
from vellum.workflows.events.types import WorkflowDeploymentParentContext, default_serializer
|
|
23
23
|
from vellum.workflows.events.workflow import is_workflow_event
|
|
24
24
|
from vellum.workflows.exceptions import NodeException, WorkflowInitializationException
|
|
25
25
|
from vellum.workflows.inputs.base import BaseInputs
|
|
26
26
|
from vellum.workflows.nodes.bases.base import BaseNode
|
|
27
27
|
from vellum.workflows.outputs.base import BaseOutput
|
|
28
|
+
from vellum.workflows.state.context import WorkflowContext, WorkflowDeploymentMetadata
|
|
28
29
|
from vellum.workflows.types.core import EntityInputsInterface, MergeBehavior
|
|
29
30
|
from vellum.workflows.types.generics import StateType
|
|
30
31
|
from vellum.workflows.workflows.event_filters import all_workflow_event_filter
|
|
@@ -155,69 +156,95 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
|
|
|
155
156
|
filtered_inputs = {k: v for k, v in self.subworkflow_inputs.items() if v is not None}
|
|
156
157
|
return inputs_class(**filtered_inputs)
|
|
157
158
|
|
|
158
|
-
def _run_resolved_workflow(
|
|
159
|
+
def _run_resolved_workflow(
|
|
160
|
+
self,
|
|
161
|
+
workflow_class: Type["BaseWorkflow"],
|
|
162
|
+
deployment_metadata: Optional[WorkflowDeploymentMetadata],
|
|
163
|
+
) -> Iterator[BaseOutput]:
|
|
159
164
|
"""Execute resolved workflow directly (similar to InlineSubworkflowNode)."""
|
|
160
|
-
|
|
165
|
+
# Construct the parent context hierarchy for the subworkflow
|
|
166
|
+
parent_context = get_parent_context()
|
|
167
|
+
|
|
168
|
+
# If we have deployment metadata, wrap the parent context with WorkflowDeploymentParentContext
|
|
169
|
+
if deployment_metadata:
|
|
170
|
+
parent_context = WorkflowDeploymentParentContext(
|
|
171
|
+
span_id=uuid4(),
|
|
172
|
+
deployment_id=deployment_metadata.deployment_id,
|
|
173
|
+
deployment_name=deployment_metadata.deployment_name,
|
|
174
|
+
deployment_history_item_id=deployment_metadata.deployment_history_item_id,
|
|
175
|
+
release_tag_id=deployment_metadata.release_tag_id,
|
|
176
|
+
release_tag_name=deployment_metadata.release_tag_name,
|
|
177
|
+
workflow_version_id=deployment_metadata.workflow_version_id,
|
|
178
|
+
external_id=self.external_id if self.external_id is not OMIT else None,
|
|
179
|
+
metadata=self.metadata if self.metadata is not OMIT else None,
|
|
180
|
+
parent=parent_context,
|
|
181
|
+
)
|
|
182
|
+
|
|
183
|
+
with execution_context(parent_context=parent_context):
|
|
184
|
+
# Instantiate the workflow inside the execution context so it captures the correct parent context
|
|
185
|
+
resolved_workflow = workflow_class(
|
|
186
|
+
context=WorkflowContext.create_from(self._context), parent_state=self.state
|
|
187
|
+
)
|
|
161
188
|
subworkflow_stream = resolved_workflow.stream(
|
|
162
189
|
inputs=self._compile_subworkflow_inputs_for_direct_invocation(resolved_workflow),
|
|
163
190
|
event_filter=all_workflow_event_filter,
|
|
164
191
|
node_output_mocks=self._context._get_all_node_output_mocks(),
|
|
165
192
|
)
|
|
166
193
|
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
194
|
+
try:
|
|
195
|
+
first_event = next(subworkflow_stream)
|
|
196
|
+
self._context._emit_subworkflow_event(first_event)
|
|
197
|
+
except WorkflowInitializationException as e:
|
|
198
|
+
hashed_module = e.definition.__module__
|
|
199
|
+
raise NodeException(
|
|
200
|
+
message=e.message,
|
|
201
|
+
code=e.code,
|
|
202
|
+
raw_data={"hashed_module": hashed_module},
|
|
203
|
+
) from e
|
|
204
|
+
|
|
205
|
+
outputs = None
|
|
206
|
+
exception = None
|
|
207
|
+
fulfilled_output_names: Set[str] = set()
|
|
208
|
+
|
|
209
|
+
for event in subworkflow_stream:
|
|
210
|
+
self._context._emit_subworkflow_event(event)
|
|
211
|
+
if exception:
|
|
212
|
+
continue
|
|
213
|
+
|
|
214
|
+
if not is_workflow_event(event):
|
|
215
|
+
continue
|
|
216
|
+
if event.workflow_definition != resolved_workflow.__class__:
|
|
217
|
+
continue
|
|
218
|
+
|
|
219
|
+
if event.name == "workflow.execution.streaming":
|
|
220
|
+
if event.output.is_fulfilled:
|
|
221
|
+
fulfilled_output_names.add(event.output.name)
|
|
222
|
+
yield event.output
|
|
223
|
+
elif event.name == "workflow.execution.fulfilled":
|
|
224
|
+
outputs = event.outputs
|
|
225
|
+
elif event.name == "workflow.execution.rejected":
|
|
226
|
+
exception = NodeException.of(event.error)
|
|
227
|
+
elif event.name == "workflow.execution.paused":
|
|
228
|
+
exception = NodeException(
|
|
229
|
+
code=WorkflowErrorCode.INVALID_OUTPUTS,
|
|
230
|
+
message="Subworkflow unexpectedly paused",
|
|
231
|
+
)
|
|
181
232
|
|
|
182
|
-
for event in subworkflow_stream:
|
|
183
|
-
self._context._emit_subworkflow_event(event)
|
|
184
233
|
if exception:
|
|
185
|
-
|
|
234
|
+
raise exception
|
|
186
235
|
|
|
187
|
-
if
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
continue
|
|
191
|
-
|
|
192
|
-
if event.name == "workflow.execution.streaming":
|
|
193
|
-
if event.output.is_fulfilled:
|
|
194
|
-
fulfilled_output_names.add(event.output.name)
|
|
195
|
-
yield event.output
|
|
196
|
-
elif event.name == "workflow.execution.fulfilled":
|
|
197
|
-
outputs = event.outputs
|
|
198
|
-
elif event.name == "workflow.execution.rejected":
|
|
199
|
-
exception = NodeException.of(event.error)
|
|
200
|
-
elif event.name == "workflow.execution.paused":
|
|
201
|
-
exception = NodeException(
|
|
236
|
+
if outputs is None:
|
|
237
|
+
raise NodeException(
|
|
238
|
+
message="Expected to receive outputs from Workflow Deployment",
|
|
202
239
|
code=WorkflowErrorCode.INVALID_OUTPUTS,
|
|
203
|
-
message="Subworkflow unexpectedly paused",
|
|
204
240
|
)
|
|
205
241
|
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
code=WorkflowErrorCode.INVALID_OUTPUTS,
|
|
213
|
-
)
|
|
214
|
-
|
|
215
|
-
for output_descriptor, output_value in outputs:
|
|
216
|
-
if output_descriptor.name not in fulfilled_output_names:
|
|
217
|
-
yield BaseOutput(
|
|
218
|
-
name=output_descriptor.name,
|
|
219
|
-
value=output_value,
|
|
220
|
-
)
|
|
242
|
+
for output_descriptor, output_value in outputs:
|
|
243
|
+
if output_descriptor.name not in fulfilled_output_names:
|
|
244
|
+
yield BaseOutput(
|
|
245
|
+
name=output_descriptor.name,
|
|
246
|
+
value=output_value,
|
|
247
|
+
)
|
|
221
248
|
|
|
222
249
|
def run(self) -> Iterator[BaseOutput]:
|
|
223
250
|
execution_context = get_execution_context()
|
|
@@ -248,11 +275,12 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
|
|
|
248
275
|
message="Expected deployment name to be provided for subworkflow execution.",
|
|
249
276
|
)
|
|
250
277
|
|
|
251
|
-
|
|
278
|
+
resolved_result = self._context.resolve_workflow_deployment(
|
|
252
279
|
deployment_name=deployment_name, release_tag=self.release_tag, state=self.state
|
|
253
280
|
)
|
|
254
|
-
if
|
|
255
|
-
|
|
281
|
+
if resolved_result:
|
|
282
|
+
workflow_class, deployment_metadata = resolved_result
|
|
283
|
+
yield from self._run_resolved_workflow(workflow_class, deployment_metadata)
|
|
256
284
|
return
|
|
257
285
|
|
|
258
286
|
try:
|
|
@@ -74,6 +74,9 @@ from vellum.workflows.references import ExternalInputReference, OutputReference
|
|
|
74
74
|
from vellum.workflows.references.state_value import StateValueReference
|
|
75
75
|
from vellum.workflows.state.base import BaseState
|
|
76
76
|
from vellum.workflows.state.delta import StateDelta
|
|
77
|
+
from vellum.workflows.triggers.base import BaseTrigger
|
|
78
|
+
from vellum.workflows.triggers.integration import IntegrationTrigger
|
|
79
|
+
from vellum.workflows.triggers.manual import ManualTrigger
|
|
77
80
|
from vellum.workflows.types.core import CancelSignal
|
|
78
81
|
from vellum.workflows.types.generics import InputsType, OutputsType, StateType
|
|
79
82
|
|
|
@@ -109,6 +112,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
|
109
112
|
max_concurrency: Optional[int] = None,
|
|
110
113
|
timeout: Optional[float] = None,
|
|
111
114
|
init_execution_context: Optional[ExecutionContext] = None,
|
|
115
|
+
trigger: Optional[BaseTrigger] = None,
|
|
112
116
|
):
|
|
113
117
|
if state and external_inputs:
|
|
114
118
|
raise ValueError("Can only run a Workflow providing one of state or external inputs, not both")
|
|
@@ -198,7 +202,24 @@ class WorkflowRunner(Generic[StateType]):
|
|
|
198
202
|
)
|
|
199
203
|
|
|
200
204
|
self._entrypoints = self.workflow.get_entrypoints()
|
|
205
|
+
elif trigger:
|
|
206
|
+
# When trigger is provided, set up default state and filter entrypoints by trigger type
|
|
207
|
+
normalized_inputs = deepcopy(inputs) if inputs else self.workflow.get_default_inputs()
|
|
208
|
+
if state:
|
|
209
|
+
self._initial_state = deepcopy(state)
|
|
210
|
+
self._initial_state.meta.workflow_inputs = normalized_inputs
|
|
211
|
+
self._initial_state.meta.span_id = uuid4()
|
|
212
|
+
self._initial_state.meta.workflow_definition = self.workflow.__class__
|
|
213
|
+
else:
|
|
214
|
+
self._initial_state = self.workflow.get_default_state(normalized_inputs)
|
|
215
|
+
self._should_emit_initial_state = False
|
|
216
|
+
|
|
217
|
+
# Validate and bind trigger, then filter entrypoints
|
|
218
|
+
self._validate_and_bind_trigger(trigger)
|
|
219
|
+
self._entrypoints = self.workflow.get_entrypoints()
|
|
220
|
+
self._filter_entrypoints_for_trigger(trigger)
|
|
201
221
|
else:
|
|
222
|
+
# Default case: no entrypoint overrides and no trigger
|
|
202
223
|
normalized_inputs = deepcopy(inputs) if inputs else self.workflow.get_default_inputs()
|
|
203
224
|
if state:
|
|
204
225
|
self._initial_state = deepcopy(state)
|
|
@@ -213,6 +234,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
|
213
234
|
self._should_emit_initial_state = False
|
|
214
235
|
self._entrypoints = self.workflow.get_entrypoints()
|
|
215
236
|
|
|
237
|
+
# Check if workflow requires a trigger but none was provided
|
|
238
|
+
self._validate_no_trigger_provided()
|
|
239
|
+
|
|
216
240
|
# This queue is responsible for sending events from WorkflowRunner to the outside world
|
|
217
241
|
self._workflow_event_outer_queue: Queue[WorkflowEvent] = Queue()
|
|
218
242
|
|
|
@@ -250,6 +274,126 @@ class WorkflowRunner(Generic[StateType]):
|
|
|
250
274
|
self._cancel_thread: Optional[Thread] = None
|
|
251
275
|
self._timeout_thread: Optional[Thread] = None
|
|
252
276
|
|
|
277
|
+
def _has_manual_trigger(self) -> bool:
|
|
278
|
+
"""Check if workflow has ManualTrigger."""
|
|
279
|
+
for subgraph in self.workflow.get_subgraphs():
|
|
280
|
+
for trigger in subgraph.triggers:
|
|
281
|
+
if issubclass(trigger, ManualTrigger):
|
|
282
|
+
return True
|
|
283
|
+
return False
|
|
284
|
+
|
|
285
|
+
def _get_entrypoints_for_trigger_type(self, trigger_class: Type) -> List[Type[BaseNode]]:
|
|
286
|
+
"""Get all entrypoints connected to a specific trigger type.
|
|
287
|
+
|
|
288
|
+
Allows subclasses: if trigger_class is a subclass of any declared trigger,
|
|
289
|
+
returns those entrypoints.
|
|
290
|
+
"""
|
|
291
|
+
entrypoints: List[Type[BaseNode]] = []
|
|
292
|
+
for subgraph in self.workflow.get_subgraphs():
|
|
293
|
+
for trigger in subgraph.triggers:
|
|
294
|
+
# Check if the provided trigger_class is a subclass of the declared trigger
|
|
295
|
+
# This allows runtime instances to be subclasses of what's declared in the workflow
|
|
296
|
+
if issubclass(trigger_class, trigger):
|
|
297
|
+
entrypoints.extend(subgraph.entrypoints)
|
|
298
|
+
return entrypoints
|
|
299
|
+
|
|
300
|
+
def _validate_and_bind_trigger(self, trigger: BaseTrigger) -> None:
|
|
301
|
+
"""
|
|
302
|
+
Validate that trigger is compatible with workflow and bind it to state.
|
|
303
|
+
|
|
304
|
+
Supports all trigger types derived from BaseTrigger:
|
|
305
|
+
- IntegrationTrigger instances (Slack, Gmail, etc.)
|
|
306
|
+
- ManualTrigger instances (explicit manual execution)
|
|
307
|
+
- ScheduledTrigger instances (time-based triggers)
|
|
308
|
+
- Any future trigger types
|
|
309
|
+
|
|
310
|
+
Raises:
|
|
311
|
+
WorkflowInitializationException: If trigger type is not compatible with workflow
|
|
312
|
+
"""
|
|
313
|
+
trigger_class = type(trigger)
|
|
314
|
+
|
|
315
|
+
# Search for a compatible trigger type in the workflow
|
|
316
|
+
found_compatible_trigger = False
|
|
317
|
+
has_any_triggers = False
|
|
318
|
+
incompatible_trigger_names: List[str] = []
|
|
319
|
+
|
|
320
|
+
for subgraph in self.workflow.get_subgraphs():
|
|
321
|
+
for declared_trigger in subgraph.triggers:
|
|
322
|
+
has_any_triggers = True
|
|
323
|
+
# Allow subclasses: if workflow declares BaseSlackTrigger, accept SpecificSlackTrigger instances
|
|
324
|
+
if issubclass(trigger_class, declared_trigger):
|
|
325
|
+
found_compatible_trigger = True
|
|
326
|
+
break
|
|
327
|
+
else:
|
|
328
|
+
incompatible_trigger_names.append(declared_trigger.__name__)
|
|
329
|
+
|
|
330
|
+
if found_compatible_trigger:
|
|
331
|
+
break
|
|
332
|
+
|
|
333
|
+
# Special case: workflows with no explicit triggers implicitly support ManualTrigger
|
|
334
|
+
if not has_any_triggers and not isinstance(trigger, ManualTrigger):
|
|
335
|
+
raise WorkflowInitializationException(
|
|
336
|
+
message=f"Provided trigger type {trigger_class.__name__} is not compatible with workflow. "
|
|
337
|
+
f"Workflow has no explicit triggers and only supports ManualTrigger.",
|
|
338
|
+
workflow_definition=self.workflow.__class__,
|
|
339
|
+
code=WorkflowErrorCode.INVALID_INPUTS,
|
|
340
|
+
)
|
|
341
|
+
|
|
342
|
+
# Validate that we found a compatible trigger type
|
|
343
|
+
if has_any_triggers and not found_compatible_trigger:
|
|
344
|
+
raise WorkflowInitializationException(
|
|
345
|
+
message=f"Provided trigger type {trigger_class.__name__} is not compatible with workflow triggers. "
|
|
346
|
+
f"Workflow has: {sorted(set(incompatible_trigger_names))}",
|
|
347
|
+
workflow_definition=self.workflow.__class__,
|
|
348
|
+
code=WorkflowErrorCode.INVALID_INPUTS,
|
|
349
|
+
)
|
|
350
|
+
|
|
351
|
+
# Bind trigger to state (works for all trigger types via BaseTrigger.bind_to_state)
|
|
352
|
+
trigger.bind_to_state(self._initial_state)
|
|
353
|
+
|
|
354
|
+
def _filter_entrypoints_for_trigger(self, trigger: BaseTrigger) -> None:
|
|
355
|
+
"""
|
|
356
|
+
Filter entrypoints to those connected to the specific trigger type.
|
|
357
|
+
|
|
358
|
+
Uses the specific trigger subclass, not the parent class, allowing workflows
|
|
359
|
+
with multiple triggers to route to the correct path.
|
|
360
|
+
"""
|
|
361
|
+
trigger_class = type(trigger)
|
|
362
|
+
specific_entrypoints = self._get_entrypoints_for_trigger_type(trigger_class)
|
|
363
|
+
if specific_entrypoints:
|
|
364
|
+
self._entrypoints = specific_entrypoints
|
|
365
|
+
|
|
366
|
+
def _validate_no_trigger_provided(self) -> None:
|
|
367
|
+
"""
|
|
368
|
+
Validate that workflow can run without a trigger.
|
|
369
|
+
|
|
370
|
+
If workflow has IntegrationTrigger(s) but no ManualTrigger, it requires a trigger instance.
|
|
371
|
+
If workflow has both, filter entrypoints to ManualTrigger path only.
|
|
372
|
+
|
|
373
|
+
Raises:
|
|
374
|
+
WorkflowInitializationException: If workflow requires trigger but none was provided
|
|
375
|
+
"""
|
|
376
|
+
# Collect all IntegrationTrigger types in the workflow
|
|
377
|
+
workflow_integration_triggers = []
|
|
378
|
+
for subgraph in self.workflow.get_subgraphs():
|
|
379
|
+
for trigger_type in subgraph.triggers:
|
|
380
|
+
if issubclass(trigger_type, IntegrationTrigger):
|
|
381
|
+
workflow_integration_triggers.append(trigger_type)
|
|
382
|
+
|
|
383
|
+
if workflow_integration_triggers:
|
|
384
|
+
if not self._has_manual_trigger():
|
|
385
|
+
# Workflow has ONLY IntegrationTrigger - this is an error
|
|
386
|
+
raise WorkflowInitializationException(
|
|
387
|
+
message="Workflow has IntegrationTrigger which requires trigger parameter",
|
|
388
|
+
workflow_definition=self.workflow.__class__,
|
|
389
|
+
code=WorkflowErrorCode.INVALID_INPUTS,
|
|
390
|
+
)
|
|
391
|
+
|
|
392
|
+
# Workflow has both IntegrationTrigger and ManualTrigger - filter to ManualTrigger path
|
|
393
|
+
manual_entrypoints = self._get_entrypoints_for_trigger_type(ManualTrigger)
|
|
394
|
+
if manual_entrypoints:
|
|
395
|
+
self._entrypoints = manual_entrypoints
|
|
396
|
+
|
|
253
397
|
@contextmanager
|
|
254
398
|
def _httpx_logger_with_span_id(self) -> Iterator[None]:
|
|
255
399
|
"""
|