vellum-ai 0.14.5__py3-none-any.whl → 0.14.7__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their public registry.
- vellum/__init__.py +18 -0
- vellum/client/__init__.py +8 -8
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/resources/__init__.py +2 -0
- vellum/client/resources/workflow_sandboxes/__init__.py +3 -0
- vellum/client/resources/workflow_sandboxes/client.py +146 -0
- vellum/client/resources/workflow_sandboxes/types/__init__.py +5 -0
- vellum/client/resources/workflow_sandboxes/types/list_workflow_sandbox_examples_request_tag.py +5 -0
- vellum/client/types/__init__.py +16 -0
- vellum/client/types/array_chat_message_content_item.py +6 -1
- vellum/client/types/array_chat_message_content_item_request.py +2 -0
- vellum/client/types/chat_message_content.py +2 -0
- vellum/client/types/chat_message_content_request.py +2 -0
- vellum/client/types/document_chat_message_content.py +25 -0
- vellum/client/types/document_chat_message_content_request.py +25 -0
- vellum/client/types/document_vellum_value.py +25 -0
- vellum/client/types/document_vellum_value_request.py +25 -0
- vellum/client/types/paginated_workflow_sandbox_example_list.py +23 -0
- vellum/client/types/vellum_document.py +20 -0
- vellum/client/types/vellum_document_request.py +20 -0
- vellum/client/types/vellum_value.py +2 -0
- vellum/client/types/vellum_value_request.py +2 -0
- vellum/client/types/vellum_variable_type.py +1 -0
- vellum/client/types/workflow_sandbox_example.py +22 -0
- vellum/resources/workflow_sandboxes/types/__init__.py +3 -0
- vellum/resources/workflow_sandboxes/types/list_workflow_sandbox_examples_request_tag.py +3 -0
- vellum/types/document_chat_message_content.py +3 -0
- vellum/types/document_chat_message_content_request.py +3 -0
- vellum/types/document_vellum_value.py +3 -0
- vellum/types/document_vellum_value_request.py +3 -0
- vellum/types/paginated_workflow_sandbox_example_list.py +3 -0
- vellum/types/vellum_document.py +3 -0
- vellum/types/vellum_document_request.py +3 -0
- vellum/types/workflow_sandbox_example.py +3 -0
- vellum/workflows/exceptions.py +18 -0
- vellum/workflows/inputs/base.py +27 -1
- vellum/workflows/inputs/tests/__init__.py +0 -0
- vellum/workflows/inputs/tests/test_inputs.py +49 -0
- vellum/workflows/nodes/core/inline_subworkflow_node/node.py +1 -1
- vellum/workflows/nodes/core/map_node/node.py +7 -7
- vellum/workflows/nodes/core/try_node/node.py +1 -1
- vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +2 -2
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +5 -3
- vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +5 -4
- vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +4 -4
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +49 -15
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/test_node.py +165 -0
- vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +3 -1
- vellum/workflows/outputs/base.py +1 -1
- vellum/workflows/runner/runner.py +16 -10
- vellum/workflows/state/context.py +7 -7
- vellum/workflows/workflows/base.py +61 -59
- vellum/workflows/workflows/tests/test_base_workflow.py +131 -40
- {vellum_ai-0.14.5.dist-info → vellum_ai-0.14.7.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.5.dist-info → vellum_ai-0.14.7.dist-info}/RECORD +68 -44
- vellum_cli/__init__.py +36 -0
- vellum_cli/init.py +128 -0
- vellum_cli/pull.py +6 -3
- vellum_cli/tests/test_init.py +355 -0
- vellum_cli/tests/test_pull.py +127 -0
- vellum_ee/workflows/display/nodes/base_node_display.py +4 -4
- vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +31 -0
- vellum_ee/workflows/display/nodes/vellum/utils.py +8 -0
- vellum_ee/workflows/display/vellum.py +0 -4
- vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +29 -0
- {vellum_ai-0.14.5.dist-info → vellum_ai-0.14.7.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.5.dist-info → vellum_ai-0.14.7.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.5.dist-info → vellum_ai-0.14.7.dist-info}/entry_points.txt +0 -0
vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/test_node.py
CHANGED
```diff
@@ -3,6 +3,7 @@ from datetime import datetime
 from uuid import uuid4
 from typing import Any, Iterator, List
 
+from vellum.client.core.api_error import ApiError
 from vellum.client.types.chat_message import ChatMessage
 from vellum.client.types.chat_message_request import ChatMessageRequest
 from vellum.client.types.workflow_execution_workflow_result_event import WorkflowExecutionWorkflowResultEvent
@@ -11,6 +12,8 @@ from vellum.client.types.workflow_request_chat_history_input_request import Work
 from vellum.client.types.workflow_request_json_input_request import WorkflowRequestJsonInputRequest
 from vellum.client.types.workflow_result_event import WorkflowResultEvent
 from vellum.client.types.workflow_stream_event import WorkflowStreamEvent
+from vellum.workflows.errors import WorkflowErrorCode
+from vellum.workflows.exceptions import NodeException
 from vellum.workflows.nodes.displayable.subworkflow_deployment_node.node import SubworkflowDeploymentNode
 
 
@@ -129,3 +132,165 @@ def test_run_workflow__any_array(vellum_client):
     assert call_kwargs["inputs"] == [
         WorkflowRequestJsonInputRequest(name="fruits", value=["apple", "banana", "cherry"]),
     ]
+
+
+def test_run_workflow__no_deployment():
+    """Confirm that we raise error when running a subworkflow deployment node with no deployment attribute set"""
+
+    # GIVEN a Subworkflow Deployment Node
+    class ExampleSubworkflowDeploymentNode(SubworkflowDeploymentNode):
+        subworkflow_inputs = {
+            "fruits": ["apple", "banana", "cherry"],
+        }
+
+    # WHEN/THEN running the node should raise a NodeException
+    node = ExampleSubworkflowDeploymentNode()
+    with pytest.raises(NodeException) as exc_info:
+        list(node.run())
+
+    # AND the error message should be correct
+    assert exc_info.value.code == WorkflowErrorCode.NODE_EXECUTION
+    assert "Expected subworkflow deployment attribute to be either a UUID or STR, got None instead" in str(
+        exc_info.value
+    )
+
+
+def test_run_workflow__hyphenated_output(vellum_client):
+    """Confirm that we can successfully handle subworkflow outputs with hyphenated names"""
+
+    # GIVEN a Subworkflow Deployment Node
+    class ExampleSubworkflowDeploymentNode(SubworkflowDeploymentNode):
+        deployment = "example_subworkflow_deployment"
+        subworkflow_inputs = {
+            "test_input": "test_value",
+        }
+
+        class Outputs(SubworkflowDeploymentNode.Outputs):
+            final_output_copy: str
+
+    # AND we know what the Subworkflow Deployment will respond with
+    def generate_subworkflow_events(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
+        execution_id = str(uuid4())
+        expected_events: List[WorkflowStreamEvent] = [
+            WorkflowExecutionWorkflowResultEvent(
+                execution_id=execution_id,
+                data=WorkflowResultEvent(
+                    id=str(uuid4()),
+                    state="INITIATED",
+                    ts=datetime.now(),
+                ),
+            ),
+            WorkflowExecutionWorkflowResultEvent(
+                execution_id=execution_id,
+                data=WorkflowResultEvent(
+                    id=str(uuid4()),
+                    state="FULFILLED",
+                    ts=datetime.now(),
+                    outputs=[
+                        WorkflowOutputString(
+                            id=str(uuid4()),
+                            name="final-output_copy",  # Note the hyphen here
+                            value="test success",
+                        )
+                    ],
+                ),
+            ),
+        ]
+        yield from expected_events
+
+    vellum_client.execute_workflow_stream.side_effect = generate_subworkflow_events
+
+    # WHEN we run the node
+    node = ExampleSubworkflowDeploymentNode()
+    events = list(node.run())
+
+    # THEN the node should have completed successfully
+    assert events[-1].name == "final_output_copy"  # Note the underscore here
+    assert events[-1].value == "test success"
+
+
+@pytest.mark.parametrize(
+    ["exception", "expected_code", "expected_message"],
+    [
+        (
+            ApiError(status_code=400, body={"detail": "Missing required input variable: 'foo'"}),
+            WorkflowErrorCode.INVALID_INPUTS,
+            "Missing required input variable: 'foo'",
+        ),
+        (
+            ApiError(status_code=400, body={"message": "Missing required input variable: 'foo'"}),
+            WorkflowErrorCode.INVALID_INPUTS,
+            "Failed to execute Subworkflow Deployment",
+        ),
+        (
+            ApiError(status_code=400, body="Missing required input variable: 'foo'"),
+            WorkflowErrorCode.INTERNAL_ERROR,
+            "Failed to execute Subworkflow Deployment",
+        ),
+        (
+            ApiError(status_code=None, body={"detail": "Missing required input variable: 'foo'"}),
+            WorkflowErrorCode.INTERNAL_ERROR,
+            "Failed to execute Subworkflow Deployment",
+        ),
+        (
+            ApiError(status_code=500, body={"detail": "Missing required input variable: 'foo'"}),
+            WorkflowErrorCode.INTERNAL_ERROR,
+            "Failed to execute Subworkflow Deployment",
+        ),
+    ],
+    ids=["400", "invalid_dict", "invalid_body", "no_status_code", "500"],
+)
+def test_subworkflow_deployment_node__api_error__invalid_inputs_node_exception(
+    vellum_client, exception, expected_code, expected_message
+):
+    # GIVEN a prompt node with an invalid model name
+    class MyNode(SubworkflowDeploymentNode):
+        deployment = "example_subworkflow_deployment"
+        subworkflow_inputs = {
+            "not_foo": "bar",
+        }
+
+    # AND the Subworkflow Deployment API call fails
+    def _side_effect(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
+        if kwargs.get("_mock_condition_to_induce_an_error"):
+            yield WorkflowExecutionWorkflowResultEvent(
+                execution_id=str(uuid4()),
+                data=WorkflowResultEvent(
+                    id=str(uuid4()),
+                    state="INITIATED",
+                    ts=datetime.now(),
+                ),
+            )
+        else:
+            raise exception
+
+    # AND the vellum client execute workflow stream raises a 4xx error
+    vellum_client.execute_workflow_stream.side_effect = _side_effect
+
+    # WHEN the node is run
+    with pytest.raises(NodeException) as e:
+        list(MyNode().run())
+
+    # THEN the node raises the correct NodeException
+    assert e.value.code == expected_code
+    assert e.value.message == expected_message
+
+
+def test_subworkflow_deployment_node__immediate_api_error__node_exception(vellum_client):
+    # GIVEN a prompt node with an invalid model name
+    class MyNode(SubworkflowDeploymentNode):
+        deployment = "example_subworkflow_deployment"
+        subworkflow_inputs = {
+            "not_foo": "bar",
+        }
+
+    # AND the vellum client execute workflow stream raises a 4xx error
+    vellum_client.execute_workflow_stream.side_effect = ApiError(status_code=404, body={"detail": "Not found"})
+
+    # WHEN the node is run
+    with pytest.raises(NodeException) as e:
+        list(MyNode().run())
+
+    # THEN the node raises the correct NodeException
+    assert e.value.code == WorkflowErrorCode.INVALID_INPUTS
+    assert e.value.message == "Not found"
```
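The parametrized cases above pin down how `SubworkflowDeploymentNode` is expected to translate `ApiError` responses into `NodeException` codes: a 4xx status with a dict body yields `INVALID_INPUTS` (surfacing the body's string `detail` when present), while anything else falls back to `INTERNAL_ERROR` with a generic message. The snippet below is a standalone sketch of that mapping inferred from the tests, not the node's actual implementation; `FakeApiError` and `map_api_error` are illustrative names.

```python
# Standalone sketch of the error mapping the parametrized tests above expect.
# FakeApiError and map_api_error are hypothetical stand-ins, not SDK code.
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass
class FakeApiError:
    status_code: Optional[int]
    body: Any


def map_api_error(exc: FakeApiError) -> Tuple[str, str]:
    """Return the (error code, message) pair the tests assert on."""
    generic_message = "Failed to execute Subworkflow Deployment"
    if exc.status_code is not None and 400 <= exc.status_code < 500 and isinstance(exc.body, dict):
        # 4xx with a dict body counts as invalid inputs; surface the string "detail" if present
        detail = exc.body.get("detail")
        message = detail if isinstance(detail, str) else generic_message
        return ("INVALID_INPUTS", message)
    # Non-4xx statuses, missing status codes, and non-dict bodies become internal errors
    return ("INTERNAL_ERROR", generic_message)


assert map_api_error(FakeApiError(400, {"detail": "Missing required input variable: 'foo'"})) == (
    "INVALID_INPUTS",
    "Missing required input variable: 'foo'",
)
assert map_api_error(FakeApiError(500, {"detail": "whoops"}))[0] == "INTERNAL_ERROR"
```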
vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py
CHANGED
```diff
@@ -74,5 +74,7 @@ def test_text_prompt_deployment_node__basic(vellum_client):
         prompt_deployment_name="my-deployment",
         raw_overrides=OMIT,
         release_tag="LATEST",
-        request_options={
+        request_options={
+            "additional_body_parameters": {"execution_context": {"parent_context": None, "trace_id": None}}
+        },
     )
```
vellum/workflows/outputs/base.py
CHANGED
```diff
@@ -32,7 +32,7 @@ class BaseOutput(Generic[_Delta, _Accumulated]):
         if value is not undefined and delta is not undefined:
             raise ValueError("Cannot set both value and delta")
 
-        self._name = name
+        self._name = name.replace("-", "_")  # Convert hyphens to underscores for valid python variable names
         self._value = value
         self._delta = delta
 
```
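The one-line change above means `BaseOutput` now normalizes hyphenated output names into valid Python identifiers, which is what the hyphenated-output test in the subworkflow deployment node tests relies on. A minimal stand-in illustrating the rule (assuming the stored name is exposed via a `name` property, as those tests suggest; this is not the real class):

```python
# Minimal stand-in for the normalization rule added above. The real class is
# vellum.workflows.outputs.base.BaseOutput; this demo only mirrors the name handling.
class OutputNameDemo:
    def __init__(self, name: str) -> None:
        # Convert hyphens to underscores so the name is a valid Python identifier
        self._name = name.replace("-", "_")

    @property
    def name(self) -> str:
        return self._name


assert OutputNameDemo("final-output_copy").name == "final_output_copy"
assert OutputNameDemo("already_valid").name == "already_valid"
```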
vellum/workflows/runner/runner.py
CHANGED
```diff
@@ -7,7 +7,7 @@ from uuid import UUID
 from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Tuple, Type, Union
 
 from vellum.workflows.constants import undefined
-from vellum.workflows.context import execution_context, get_parent_context
+from vellum.workflows.context import ExecutionContext, execution_context, get_execution_context, get_parent_context
 from vellum.workflows.descriptors.base import BaseDescriptor
 from vellum.workflows.edges.edge import Edge
 from vellum.workflows.errors import WorkflowError, WorkflowErrorCode
@@ -29,7 +29,7 @@ from vellum.workflows.events.node import (
     NodeExecutionRejectedBody,
     NodeExecutionStreamingBody,
 )
-from vellum.workflows.events.types import BaseEvent, NodeParentContext,
+from vellum.workflows.events.types import BaseEvent, NodeParentContext, WorkflowParentContext
 from vellum.workflows.events.workflow import (
     WorkflowExecutionFulfilledBody,
     WorkflowExecutionInitiatedBody,
@@ -75,8 +75,8 @@ class WorkflowRunner(Generic[StateType]):
         external_inputs: Optional[ExternalInputsArg] = None,
         cancel_signal: Optional[ThreadingEvent] = None,
         node_output_mocks: Optional[MockNodeExecutionArg] = None,
-        parent_context: Optional[ParentContext] = None,
         max_concurrency: Optional[int] = None,
+        init_execution_context: Optional[ExecutionContext] = None,
     ):
         if state and external_inputs:
             raise ValueError("Can only run a Workflow providing one of state or external inputs, not both")
@@ -98,6 +98,11 @@ class WorkflowRunner(Generic[StateType]):
         elif external_inputs:
             self._initial_state = self.workflow.get_most_recent_state()
             for descriptor, value in external_inputs.items():
+                if not any(isinstance(value, type_) for type_ in descriptor.types):
+                    raise NodeException(
+                        f"Invalid external input type for {descriptor.name}",
+                        code=WorkflowErrorCode.INVALID_INPUTS,
+                    )
                 self._initial_state.meta.external_inputs[descriptor] = value
 
         self._entrypoints = [
@@ -133,7 +138,8 @@ class WorkflowRunner(Generic[StateType]):
 
         self._active_nodes_by_execution_id: Dict[UUID, BaseNode[StateType]] = {}
         self._cancel_signal = cancel_signal
-        self.
+        self._execution_context = init_execution_context or get_execution_context()
+        self._parent_context = self._execution_context.parent_context
 
         setattr(
             self._initial_state,
@@ -196,7 +202,7 @@ class WorkflowRunner(Generic[StateType]):
                 break
 
         if not was_mocked:
-            with execution_context(parent_context=updated_parent_context):
+            with execution_context(parent_context=updated_parent_context, trace_id=node.state.meta.trace_id):
                 node_run_response = node.run()
 
         ports = node.Ports()
@@ -243,7 +249,7 @@ class WorkflowRunner(Generic[StateType]):
             ),
         )
 
-        with execution_context(parent_context=updated_parent_context):
+        with execution_context(parent_context=updated_parent_context, trace_id=node.state.meta.trace_id):
             for output in node_run_response:
                 invoked_ports = output > ports
                 if output.is_initiated:
@@ -346,7 +352,7 @@ class WorkflowRunner(Generic[StateType]):
         if parent_context is None:
            parent_context = get_parent_context() or self._parent_context
 
-        with execution_context(parent_context=parent_context):
+        with execution_context(parent_context=parent_context, trace_id=node.state.meta.trace_id):
             self._run_work_item(node, span_id)
 
     def _handle_invoked_ports(self, state: StateType, ports: Optional[Iterable[Port]]) -> None:
@@ -524,7 +530,7 @@ class WorkflowRunner(Generic[StateType]):
         for node_cls in self._entrypoints:
             try:
                 if not self._max_concurrency or len(self._active_nodes_by_execution_id) < self._max_concurrency:
-                    with execution_context(parent_context=current_parent):
+                    with execution_context(parent_context=current_parent, trace_id=self._initial_state.meta.trace_id):
                         self._run_node_if_ready(self._initial_state, node_cls)
                 else:
                     self._concurrency_queue.put((self._initial_state, node_cls, None))
@@ -551,7 +557,7 @@ class WorkflowRunner(Generic[StateType]):
 
             self._workflow_event_outer_queue.put(event)
 
-            with execution_context(parent_context=current_parent):
+            with execution_context(parent_context=current_parent, trace_id=self._initial_state.meta.trace_id):
                 rejection_error = self._handle_work_item_event(event)
 
             if rejection_error:
@@ -562,7 +568,7 @@ class WorkflowRunner(Generic[StateType]):
                 while event := self._workflow_event_inner_queue.get_nowait():
                     self._workflow_event_outer_queue.put(event)
 
-                    with execution_context(parent_context=current_parent):
+                    with execution_context(parent_context=current_parent, trace_id=self._initial_state.meta.trace_id):
                         rejection_error = self._handle_work_item_event(event)
 
                     if rejection_error:
```
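Among the runner changes above, the new block in the `@@ -98,6 +98,11 @@` hunk validates each external input against its descriptor's declared types before recording it. A rough standalone sketch of that check, with a hypothetical `ExternalInputDescriptor` standing in for the SDK's descriptor type and a plain `ValueError` in place of `NodeException`:

```python
# Rough sketch of the external-input type check added to WorkflowRunner.__init__.
# ExternalInputDescriptor is a hypothetical stand-in; the SDK uses its own descriptors
# and raises NodeException with WorkflowErrorCode.INVALID_INPUTS instead of ValueError.
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class ExternalInputDescriptor:
    name: str
    types: Tuple[type, ...]  # accepted runtime types for this input


def validate_external_input(descriptor: ExternalInputDescriptor, value: Any) -> None:
    # Mirrors: if not any(isinstance(value, type_) for type_ in descriptor.types): raise ...
    if not any(isinstance(value, type_) for type_ in descriptor.types):
        raise ValueError(f"Invalid external input type for {descriptor.name}")


validate_external_input(ExternalInputDescriptor("city", (str,)), "Paris")  # accepted
try:
    validate_external_input(ExternalInputDescriptor("count", (int, float)), "three")
except ValueError as exc:
    assert "Invalid external input type for count" in str(exc)
```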
vellum/workflows/state/context.py
CHANGED
```diff
@@ -3,7 +3,7 @@ from queue import Queue
 from typing import TYPE_CHECKING, Dict, List, Optional, Type
 
 from vellum import Vellum
-from vellum.workflows.
+from vellum.workflows.context import ExecutionContext, get_execution_context
 from vellum.workflows.nodes.mocks import MockNodeExecution, MockNodeExecutionArg
 from vellum.workflows.outputs.base import BaseOutputs
 from vellum.workflows.references.constant import ConstantValueReference
@@ -18,12 +18,14 @@ class WorkflowContext:
         self,
         *,
         vellum_client: Optional[Vellum] = None,
-
+        execution_context: Optional[ExecutionContext] = None,
     ):
         self._vellum_client = vellum_client
-        self._parent_context = parent_context
         self._event_queue: Optional[Queue["WorkflowEvent"]] = None
         self._node_output_mocks_map: Dict[Type[BaseOutputs], List[MockNodeExecution]] = {}
+        self._execution_context = get_execution_context()
+        if not self._execution_context.parent_context and execution_context:
+            self._execution_context = execution_context
 
     @cached_property
     def vellum_client(self) -> Vellum:
@@ -33,10 +35,8 @@ class WorkflowContext:
             return create_vellum_client()
 
     @cached_property
-    def
-
-        return self._parent_context
-        return None
+    def execution_context(self) -> ExecutionContext:
+        return self._execution_context
 
     @cached_property
     def node_output_mocks_map(self) -> Dict[Type[BaseOutputs], List[MockNodeExecution]]:
```
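The constructor change above makes `WorkflowContext` prefer the ambient execution context from `get_execution_context()` and only fall back to an explicitly passed `execution_context` when the ambient one has no `parent_context`. A simplified stand-in of that resolution order (assuming `ExecutionContext` carries `parent_context` and `trace_id` fields, as the runner changes above suggest; this is not the SDK's class):

```python
# Simplified stand-in for the resolution order in WorkflowContext.__init__.
# SimpleExecutionContext only mimics the two fields used here.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SimpleExecutionContext:
    parent_context: Optional[object] = None
    trace_id: Optional[str] = None


def resolve_execution_context(
    ambient: SimpleExecutionContext,
    explicit: Optional[SimpleExecutionContext],
) -> SimpleExecutionContext:
    # The ambient context (get_execution_context()) wins unless it has no
    # parent_context and an explicit context was provided.
    if not ambient.parent_context and explicit:
        return explicit
    return ambient


assert resolve_execution_context(
    SimpleExecutionContext(), SimpleExecutionContext(trace_id="abc")
).trace_id == "abc"
```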
vellum/workflows/workflows/base.py
CHANGED
```diff
@@ -24,6 +24,7 @@ from typing import (
     get_args,
 )
 
+from vellum.workflows.context import get_execution_context
 from vellum.workflows.edges import Edge
 from vellum.workflows.emitters.base import BaseWorkflowEmitter
 from vellum.workflows.errors import WorkflowError, WorkflowErrorCode
@@ -171,6 +172,7 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
         self.resolvers = resolvers or (self.resolvers if hasattr(self, "resolvers") else [])
         self._context = context or WorkflowContext()
         self._store = Store()
+        self._execution_context = self._context.execution_context
 
         self.validate()
 
@@ -178,48 +180,64 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
     def context(self) -> WorkflowContext:
         return self._context
 
-    @
-    def
-        original_graph = cls.graph
-        if isinstance(original_graph, Graph):
-            return [original_graph]
-        if isinstance(original_graph, set):
-            return [
-                subgraph if isinstance(subgraph, Graph) else Graph.from_node(subgraph) for subgraph in original_graph
-            ]
-        if issubclass(original_graph, BaseNode):
-            return [Graph.from_node(original_graph)]
-
-        raise ValueError(f"Unexpected graph type: {original_graph.__class__}")
-
-    @classmethod
-    def get_edges(cls) -> Iterator[Edge]:
+    @staticmethod
+    def _resolve_graph(graph: GraphAttribute) -> List[Graph]:
         """
-
-        ensure uniqueness, and the iterator to preserve order.
+        Resolves a single graph source to a list of Graph objects.
         """
+        if isinstance(graph, Graph):
+            return [graph]
+        if isinstance(graph, set):
+            graphs = []
+            for item in graph:
+                if isinstance(item, Graph):
+                    graphs.append(item)
+                elif issubclass(item, BaseNode):
+                    graphs.append(Graph.from_node(item))
+                else:
+                    raise ValueError(f"Unexpected graph type: {type(item)}")
+            return graphs
+        if issubclass(graph, BaseNode):
+            return [Graph.from_node(graph)]
+        raise ValueError(f"Unexpected graph type: {type(graph)}")
 
+    @staticmethod
+    def _get_edges_from_subgraphs(subgraphs: Iterable[Graph]) -> Iterator[Edge]:
         edges = set()
-        subgraphs = cls.get_subgraphs()
         for subgraph in subgraphs:
             for edge in subgraph.edges:
                 if edge not in edges:
                     edges.add(edge)
                     yield edge
 
+    @staticmethod
+    def _get_nodes_from_subgraphs(subgraphs: Iterable[Graph]) -> Iterator[Type[BaseNode]]:
+        nodes = set()
+        for subgraph in subgraphs:
+            for node in subgraph.nodes:
+                if node not in nodes:
+                    nodes.add(node)
+                    yield node
+
+    @classmethod
+    def get_subgraphs(cls) -> List[Graph]:
+        return cls._resolve_graph(cls.graph)
+
+    @classmethod
+    def get_edges(cls) -> Iterator[Edge]:
+        """
+        Returns an iterator over the edges in the workflow. We use a set to
+        ensure uniqueness, and the iterator to preserve order.
+        """
+        return cls._get_edges_from_subgraphs(cls.get_subgraphs())
+
     @classmethod
     def get_nodes(cls) -> Iterator[Type[BaseNode]]:
         """
         Returns an iterator over the nodes in the workflow. We use a set to
         ensure uniqueness, and the iterator to preserve order.
         """
-
-        nodes = set()
-        for subgraph in cls.get_subgraphs():
-            for node in subgraph.nodes:
-                if node not in nodes:
-                    nodes.add(node)
-                    yield node
+        return cls._get_nodes_from_subgraphs(cls.get_subgraphs())
 
     @classmethod
     def get_unused_subgraphs(cls) -> List[Graph]:
```
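The refactor above centralizes graph normalization in `_resolve_graph`, which accepts a single `Graph`, a set mixing `Graph` instances and `BaseNode` subclasses, or a bare `BaseNode` subclass; `get_subgraphs` and the edge/node iterators all reuse it. Below is a standalone sketch of the same normalization using stand-in types rather than the vellum SDK's `Graph` and `BaseNode`.

```python
# Standalone sketch of the _resolve_graph normalization introduced above, using
# stand-in DemoGraph/DemoNode classes rather than the vellum SDK's Graph/BaseNode.
from typing import List, Set, Type, Union


class DemoNode:
    """Stand-in for a BaseNode subclass."""


class DemoGraph:
    """Stand-in for the SDK's Graph type."""

    def __init__(self, nodes: List[Type[DemoNode]]) -> None:
        self.nodes = nodes

    @classmethod
    def from_node(cls, node: Type[DemoNode]) -> "DemoGraph":
        return cls([node])


GraphLike = Union[DemoGraph, Type[DemoNode], Set[Union[DemoGraph, Type[DemoNode]]]]


def resolve_graph(graph: GraphLike) -> List[DemoGraph]:
    """Normalize the three accepted `graph` forms into a list of graphs."""
    if isinstance(graph, DemoGraph):
        return [graph]
    if isinstance(graph, set):
        graphs: List[DemoGraph] = []
        for item in graph:
            if isinstance(item, DemoGraph):
                graphs.append(item)
            elif isinstance(item, type) and issubclass(item, DemoNode):
                graphs.append(DemoGraph.from_node(item))
            else:
                raise ValueError(f"Unexpected graph type: {type(item)}")
        return graphs
    if isinstance(graph, type) and issubclass(graph, DemoNode):
        return [DemoGraph.from_node(graph)]
    raise ValueError(f"Unexpected graph type: {type(graph)}")


class NodeA(DemoNode):
    pass


class NodeB(DemoNode):
    pass


assert len(resolve_graph(NodeA)) == 1
assert len(resolve_graph({NodeA, DemoGraph.from_node(NodeB)})) == 2
```

The remaining hunks of the `vellum/workflows/workflows/base.py` diff continue below.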
```diff
@@ -228,19 +246,9 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
         """
         if not hasattr(cls, "unused_graphs"):
             return []
-
         graphs = []
         for item in cls.unused_graphs:
-
-                graphs.append(item)
-            elif isinstance(item, set):
-                for subitem in item:
-                    if isinstance(subitem, Graph):
-                        graphs.append(subitem)
-                    elif issubclass(subitem, BaseNode):
-                        graphs.append(Graph.from_node(subitem))
-            elif issubclass(item, BaseNode):
-                graphs.append(Graph.from_node(item))
+            graphs.extend(cls._resolve_graph(item))
         return graphs
 
     @classmethod
@@ -248,29 +256,14 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
         """
         Returns an iterator over the nodes that are defined but not used in the graph.
         """
-
-            yield from ()
-        else:
-            nodes = set()
-            subgraphs = cls.get_unused_subgraphs()
-            for subgraph in subgraphs:
-                for node in subgraph.nodes:
-                    if node not in nodes:
-                        nodes.add(node)
-                        yield node
+        return cls._get_nodes_from_subgraphs(cls.get_unused_subgraphs())
 
     @classmethod
     def get_unused_edges(cls) -> Iterator[Edge]:
         """
         Returns an iterator over edges that are defined but not used in the graph.
         """
-
-        subgraphs = cls.get_unused_subgraphs()
-        for subgraph in subgraphs:
-            for edge in subgraph.edges:
-                if edge not in edges:
-                    edges.add(edge)
-                    yield edge
+        return cls._get_edges_from_subgraphs(cls.get_unused_subgraphs())
 
     @classmethod
     def get_entrypoints(cls) -> Iterable[Type[BaseNode]]:
@@ -329,8 +322,8 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
             external_inputs=external_inputs,
             cancel_signal=cancel_signal,
             node_output_mocks=node_output_mocks,
-            parent_context=self._context.parent_context,
             max_concurrency=max_concurrency,
+            init_execution_context=self._execution_context,
         ).stream()
         first_event: Optional[Union[WorkflowExecutionInitiatedEvent, WorkflowExecutionResumedEvent]] = None
         last_event = None
@@ -440,8 +433,8 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
             external_inputs=external_inputs,
             cancel_signal=cancel_signal,
             node_output_mocks=node_output_mocks,
-            parent_context=self.context.parent_context,
             max_concurrency=max_concurrency,
+            init_execution_context=self._execution_context,
         ).stream():
             if should_yield(self.__class__, event):
                 yield event
@@ -488,10 +481,19 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
         return self.get_inputs_class()()
 
     def get_default_state(self, workflow_inputs: Optional[InputsType] = None) -> StateType:
+        execution_context = get_execution_context()
         return self.get_state_class()(
-            meta=
-
-
+            meta=(
+                StateMeta(
+                    parent=self._parent_state,
+                    workflow_inputs=workflow_inputs or self.get_default_inputs(),
+                    trace_id=execution_context.trace_id,
+                )
+                if execution_context and execution_context.trace_id
+                else StateMeta(
+                    parent=self._parent_state,
+                    workflow_inputs=workflow_inputs or self.get_default_inputs(),
+                )
             )
         )
 
```