vellum-ai 0.14.6__py3-none-any.whl → 0.14.7__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- vellum/__init__.py +12 -0
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/types/__init__.py +12 -0
- vellum/client/types/array_chat_message_content_item.py +6 -1
- vellum/client/types/array_chat_message_content_item_request.py +2 -0
- vellum/client/types/chat_message_content.py +2 -0
- vellum/client/types/chat_message_content_request.py +2 -0
- vellum/client/types/document_chat_message_content.py +25 -0
- vellum/client/types/document_chat_message_content_request.py +25 -0
- vellum/client/types/document_vellum_value.py +25 -0
- vellum/client/types/document_vellum_value_request.py +25 -0
- vellum/client/types/vellum_document.py +20 -0
- vellum/client/types/vellum_document_request.py +20 -0
- vellum/client/types/vellum_value.py +2 -0
- vellum/client/types/vellum_value_request.py +2 -0
- vellum/client/types/vellum_variable_type.py +1 -0
- vellum/types/document_chat_message_content.py +3 -0
- vellum/types/document_chat_message_content_request.py +3 -0
- vellum/types/document_vellum_value.py +3 -0
- vellum/types/document_vellum_value_request.py +3 -0
- vellum/types/vellum_document.py +3 -0
- vellum/types/vellum_document_request.py +3 -0
- vellum/workflows/exceptions.py +18 -0
- vellum/workflows/inputs/base.py +27 -1
- vellum/workflows/inputs/tests/__init__.py +0 -0
- vellum/workflows/inputs/tests/test_inputs.py +49 -0
- vellum/workflows/nodes/core/inline_subworkflow_node/node.py +1 -1
- vellum/workflows/nodes/core/map_node/node.py +7 -7
- vellum/workflows/nodes/core/try_node/node.py +1 -1
- vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +2 -2
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +5 -3
- vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +5 -4
- vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +4 -4
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +39 -15
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/test_node.py +142 -0
- vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +3 -1
- vellum/workflows/outputs/base.py +1 -1
- vellum/workflows/runner/runner.py +16 -10
- vellum/workflows/state/context.py +7 -7
- vellum/workflows/workflows/base.py +16 -5
- vellum/workflows/workflows/tests/test_base_workflow.py +131 -40
- {vellum_ai-0.14.6.dist-info → vellum_ai-0.14.7.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.6.dist-info → vellum_ai-0.14.7.dist-info}/RECORD +54 -38
- vellum_cli/__init__.py +36 -0
- vellum_cli/init.py +128 -0
- vellum_cli/pull.py +6 -3
- vellum_cli/tests/test_init.py +355 -0
- vellum_cli/tests/test_pull.py +127 -0
- vellum_ee/workflows/display/nodes/base_node_display.py +4 -4
- vellum_ee/workflows/display/vellum.py +0 -4
- vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +29 -0
- {vellum_ai-0.14.6.dist-info → vellum_ai-0.14.7.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.6.dist-info → vellum_ai-0.14.7.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.6.dist-info → vellum_ai-0.14.7.dist-info}/entry_points.txt +0 -0
vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py
CHANGED
@@ -146,22 +146,22 @@ def test_inline_prompt_node__function_definitions(vellum_adhoc_prompt_client):
         (
             ApiError(status_code=404, body={"message": "Model not found"}),
             WorkflowErrorCode.INVALID_INPUTS,
-            "Failed to execute
+            "Failed to execute Prompt",
         ),
         (
             ApiError(status_code=404, body="Model not found"),
             WorkflowErrorCode.INTERNAL_ERROR,
-            "Failed to execute
+            "Failed to execute Prompt",
         ),
         (
             ApiError(status_code=None, body={"detail": "Model not found"}),
             WorkflowErrorCode.INTERNAL_ERROR,
-            "Failed to execute
+            "Failed to execute Prompt",
        ),
        (
             ApiError(status_code=500, body={"detail": "Model not found"}),
             WorkflowErrorCode.INTERNAL_ERROR,
-            "Failed to execute
+            "Failed to execute Prompt",
         ),
     ],
     ids=["404", "invalid_dict", "invalid_body", "no_status_code", "500"],
vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py
CHANGED
@@ -12,10 +12,11 @@ from vellum import (
     WorkflowRequestNumberInputRequest,
     WorkflowRequestStringInputRequest,
 )
+from vellum.client.core.api_error import ApiError
 from vellum.client.types.chat_message_request import ChatMessageRequest
 from vellum.core import RequestOptions
 from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT
-from vellum.workflows.context import
+from vellum.workflows.context import get_execution_context
 from vellum.workflows.errors import WorkflowErrorCode
 from vellum.workflows.errors.types import workflow_event_error_to_workflow_error
 from vellum.workflows.events.types import default_serializer
@@ -120,11 +121,13 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
         return compiled_inputs

     def run(self) -> Iterator[BaseOutput]:
-
-        parent_context =
+        current_context = get_execution_context()
+        parent_context = (
+            current_context.parent_context.model_dump(mode="json") if current_context.parent_context else None
+        )
         request_options = self.request_options or RequestOptions()
         request_options["additional_body_parameters"] = {
-            "execution_context": {"parent_context": parent_context},
+            "execution_context": {"parent_context": parent_context, "trace_id": current_context.trace_id},
             **request_options.get("additional_body_parameters", {}),
         }

@@ -137,17 +140,26 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
                 message="Expected subworkflow deployment attribute to be either a UUID or STR, got None instead",
             )

-
-
-
-
-
-
-
-
-
-
-
+        try:
+            subworkflow_stream = self._context.vellum_client.execute_workflow_stream(
+                inputs=self._compile_subworkflow_inputs(),
+                workflow_deployment_id=deployment_id,
+                workflow_deployment_name=deployment_name,
+                release_tag=self.release_tag,
+                external_id=self.external_id,
+                event_types=["WORKFLOW"],
+                metadata=self.metadata,
+                request_options=request_options,
+            )
+        except ApiError as e:
+            self._handle_api_error(e)
+
+        # We don't use the INITIATED event anyway, so we can just skip it
+        # and use the exception handling to catch other api level errors
+        try:
+            next(subworkflow_stream)
+        except ApiError as e:
+            self._handle_api_error(e)

         outputs: Optional[List[WorkflowOutput]] = None
         fulfilled_output_names: Set[str] = set()
@@ -195,3 +207,15 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
                     name=output.name,
                     value=output.value,
                 )
+
+    def _handle_api_error(self, e: ApiError):
+        if e.status_code and e.status_code >= 400 and e.status_code < 500 and isinstance(e.body, dict):
+            raise NodeException(
+                message=e.body.get("detail", "Failed to execute Subworkflow Deployment"),
+                code=WorkflowErrorCode.INVALID_INPUTS,
+            ) from e
+
+        raise NodeException(
+            message="Failed to execute Subworkflow Deployment",
+            code=WorkflowErrorCode.INTERNAL_ERROR,
+        ) from e
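The new `_handle_api_error` helper above decides whether a failed `execute_workflow_stream` call surfaces as invalid inputs or as an internal error. Below is a minimal sketch of that mapping, assuming only the `ApiError` constructor shown in this diff (`status_code`, `body`); the `classify_subworkflow_error` function is illustrative and not part of the SDK.

```python
from vellum.client.core.api_error import ApiError


def classify_subworkflow_error(e: ApiError) -> str:
    # 4xx responses with a dict body surface the server-provided "detail" as INVALID_INPUTS;
    # anything else collapses to a generic INTERNAL_ERROR message.
    if e.status_code and 400 <= e.status_code < 500 and isinstance(e.body, dict):
        return "INVALID_INPUTS: " + e.body.get("detail", "Failed to execute Subworkflow Deployment")
    return "INTERNAL_ERROR: Failed to execute Subworkflow Deployment"


print(classify_subworkflow_error(ApiError(status_code=400, body={"detail": "Missing required input variable: 'foo'"})))
print(classify_subworkflow_error(ApiError(status_code=500, body={"detail": "Upstream failure"})))
```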
vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/test_node.py
CHANGED
@@ -3,6 +3,7 @@ from datetime import datetime
 from uuid import uuid4
 from typing import Any, Iterator, List

+from vellum.client.core.api_error import ApiError
 from vellum.client.types.chat_message import ChatMessage
 from vellum.client.types.chat_message_request import ChatMessageRequest
 from vellum.client.types.workflow_execution_workflow_result_event import WorkflowExecutionWorkflowResultEvent
@@ -152,3 +153,144 @@ def test_run_workflow__no_deployment():
     assert "Expected subworkflow deployment attribute to be either a UUID or STR, got None instead" in str(
         exc_info.value
     )
+
+
+def test_run_workflow__hyphenated_output(vellum_client):
+    """Confirm that we can successfully handle subworkflow outputs with hyphenated names"""
+
+    # GIVEN a Subworkflow Deployment Node
+    class ExampleSubworkflowDeploymentNode(SubworkflowDeploymentNode):
+        deployment = "example_subworkflow_deployment"
+        subworkflow_inputs = {
+            "test_input": "test_value",
+        }
+
+        class Outputs(SubworkflowDeploymentNode.Outputs):
+            final_output_copy: str
+
+    # AND we know what the Subworkflow Deployment will respond with
+    def generate_subworkflow_events(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
+        execution_id = str(uuid4())
+        expected_events: List[WorkflowStreamEvent] = [
+            WorkflowExecutionWorkflowResultEvent(
+                execution_id=execution_id,
+                data=WorkflowResultEvent(
+                    id=str(uuid4()),
+                    state="INITIATED",
+                    ts=datetime.now(),
+                ),
+            ),
+            WorkflowExecutionWorkflowResultEvent(
+                execution_id=execution_id,
+                data=WorkflowResultEvent(
+                    id=str(uuid4()),
+                    state="FULFILLED",
+                    ts=datetime.now(),
+                    outputs=[
+                        WorkflowOutputString(
+                            id=str(uuid4()),
+                            name="final-output_copy",  # Note the hyphen here
+                            value="test success",
+                        )
+                    ],
+                ),
+            ),
+        ]
+        yield from expected_events
+
+    vellum_client.execute_workflow_stream.side_effect = generate_subworkflow_events
+
+    # WHEN we run the node
+    node = ExampleSubworkflowDeploymentNode()
+    events = list(node.run())
+
+    # THEN the node should have completed successfully
+    assert events[-1].name == "final_output_copy"  # Note the underscore here
+    assert events[-1].value == "test success"
+
+
+@pytest.mark.parametrize(
+    ["exception", "expected_code", "expected_message"],
+    [
+        (
+            ApiError(status_code=400, body={"detail": "Missing required input variable: 'foo'"}),
+            WorkflowErrorCode.INVALID_INPUTS,
+            "Missing required input variable: 'foo'",
+        ),
+        (
+            ApiError(status_code=400, body={"message": "Missing required input variable: 'foo'"}),
+            WorkflowErrorCode.INVALID_INPUTS,
+            "Failed to execute Subworkflow Deployment",
+        ),
+        (
+            ApiError(status_code=400, body="Missing required input variable: 'foo'"),
+            WorkflowErrorCode.INTERNAL_ERROR,
+            "Failed to execute Subworkflow Deployment",
+        ),
+        (
+            ApiError(status_code=None, body={"detail": "Missing required input variable: 'foo'"}),
+            WorkflowErrorCode.INTERNAL_ERROR,
+            "Failed to execute Subworkflow Deployment",
+        ),
+        (
+            ApiError(status_code=500, body={"detail": "Missing required input variable: 'foo'"}),
+            WorkflowErrorCode.INTERNAL_ERROR,
+            "Failed to execute Subworkflow Deployment",
+        ),
+    ],
+    ids=["400", "invalid_dict", "invalid_body", "no_status_code", "500"],
+)
+def test_subworkflow_deployment_node__api_error__invalid_inputs_node_exception(
+    vellum_client, exception, expected_code, expected_message
+):
+    # GIVEN a subworkflow deployment node
+    class MyNode(SubworkflowDeploymentNode):
+        deployment = "example_subworkflow_deployment"
+        subworkflow_inputs = {
+            "not_foo": "bar",
+        }
+
+    # AND the Subworkflow Deployment API call fails
+    def _side_effect(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
+        if kwargs.get("_mock_condition_to_induce_an_error"):
+            yield WorkflowExecutionWorkflowResultEvent(
+                execution_id=str(uuid4()),
+                data=WorkflowResultEvent(
+                    id=str(uuid4()),
+                    state="INITIATED",
+                    ts=datetime.now(),
+                ),
+            )
+        else:
+            raise exception
+
+    # AND the vellum client execute workflow stream raises a 4xx error
+    vellum_client.execute_workflow_stream.side_effect = _side_effect
+
+    # WHEN the node is run
+    with pytest.raises(NodeException) as e:
+        list(MyNode().run())
+
+    # THEN the node raises the correct NodeException
+    assert e.value.code == expected_code
+    assert e.value.message == expected_message
+
+
+def test_subworkflow_deployment_node__immediate_api_error__node_exception(vellum_client):
+    # GIVEN a subworkflow deployment node
+    class MyNode(SubworkflowDeploymentNode):
+        deployment = "example_subworkflow_deployment"
+        subworkflow_inputs = {
+            "not_foo": "bar",
+        }
+
+    # AND the vellum client execute workflow stream raises a 4xx error
+    vellum_client.execute_workflow_stream.side_effect = ApiError(status_code=404, body={"detail": "Not found"})
+
+    # WHEN the node is run
+    with pytest.raises(NodeException) as e:
+        list(MyNode().run())
+
+    # THEN the node raises the correct NodeException
+    assert e.value.code == WorkflowErrorCode.INVALID_INPUTS
+    assert e.value.message == "Not found"
vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py
CHANGED
@@ -74,5 +74,7 @@ def test_text_prompt_deployment_node__basic(vellum_client):
         prompt_deployment_name="my-deployment",
         raw_overrides=OMIT,
         release_tag="LATEST",
-        request_options={
+        request_options={
+            "additional_body_parameters": {"execution_context": {"parent_context": None, "trace_id": None}}
+        },
     )
vellum/workflows/outputs/base.py
CHANGED
@@ -32,7 +32,7 @@ class BaseOutput(Generic[_Delta, _Accumulated]):
         if value is not undefined and delta is not undefined:
             raise ValueError("Cannot set both value and delta")

-        self._name = name
+        self._name = name.replace("-", "_")  # Convert hyphens to underscores for valid python variable names
         self._value = value
         self._delta = delta

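This one-line change to `BaseOutput` is what makes the hyphenated-output test above pass: output names coming back from a deployment are normalized into valid Python attribute names. A minimal standalone illustration (the helper name is ours, not the SDK's):

```python
def sanitize_output_name(name: str) -> str:
    # Mirrors BaseOutput.__init__: hyphens are not valid in Python identifiers,
    # so "final-output_copy" becomes "final_output_copy".
    return name.replace("-", "_")


assert sanitize_output_name("final-output_copy") == "final_output_copy"
assert sanitize_output_name("already_valid") == "already_valid"
```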
vellum/workflows/runner/runner.py
CHANGED
@@ -7,7 +7,7 @@ from uuid import UUID
 from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Tuple, Type, Union

 from vellum.workflows.constants import undefined
-from vellum.workflows.context import execution_context, get_parent_context
+from vellum.workflows.context import ExecutionContext, execution_context, get_execution_context, get_parent_context
 from vellum.workflows.descriptors.base import BaseDescriptor
 from vellum.workflows.edges.edge import Edge
 from vellum.workflows.errors import WorkflowError, WorkflowErrorCode
@@ -29,7 +29,7 @@ from vellum.workflows.events.node import (
     NodeExecutionRejectedBody,
     NodeExecutionStreamingBody,
 )
-from vellum.workflows.events.types import BaseEvent, NodeParentContext,
+from vellum.workflows.events.types import BaseEvent, NodeParentContext, WorkflowParentContext
 from vellum.workflows.events.workflow import (
     WorkflowExecutionFulfilledBody,
     WorkflowExecutionInitiatedBody,
@@ -75,8 +75,8 @@ class WorkflowRunner(Generic[StateType]):
         external_inputs: Optional[ExternalInputsArg] = None,
         cancel_signal: Optional[ThreadingEvent] = None,
         node_output_mocks: Optional[MockNodeExecutionArg] = None,
-        parent_context: Optional[ParentContext] = None,
         max_concurrency: Optional[int] = None,
+        init_execution_context: Optional[ExecutionContext] = None,
     ):
         if state and external_inputs:
             raise ValueError("Can only run a Workflow providing one of state or external inputs, not both")
@@ -98,6 +98,11 @@ class WorkflowRunner(Generic[StateType]):
         elif external_inputs:
             self._initial_state = self.workflow.get_most_recent_state()
             for descriptor, value in external_inputs.items():
+                if not any(isinstance(value, type_) for type_ in descriptor.types):
+                    raise NodeException(
+                        f"Invalid external input type for {descriptor.name}",
+                        code=WorkflowErrorCode.INVALID_INPUTS,
+                    )
                 self._initial_state.meta.external_inputs[descriptor] = value

         self._entrypoints = [
@@ -133,7 +138,8 @@ class WorkflowRunner(Generic[StateType]):

         self._active_nodes_by_execution_id: Dict[UUID, BaseNode[StateType]] = {}
         self._cancel_signal = cancel_signal
-        self.
+        self._execution_context = init_execution_context or get_execution_context()
+        self._parent_context = self._execution_context.parent_context

         setattr(
             self._initial_state,
@@ -196,7 +202,7 @@ class WorkflowRunner(Generic[StateType]):
                 break

         if not was_mocked:
-            with execution_context(parent_context=updated_parent_context):
+            with execution_context(parent_context=updated_parent_context, trace_id=node.state.meta.trace_id):
                 node_run_response = node.run()

             ports = node.Ports()
@@ -243,7 +249,7 @@ class WorkflowRunner(Generic[StateType]):
                 ),
             )

-            with execution_context(parent_context=updated_parent_context):
+            with execution_context(parent_context=updated_parent_context, trace_id=node.state.meta.trace_id):
                 for output in node_run_response:
                     invoked_ports = output > ports
                     if output.is_initiated:
@@ -346,7 +352,7 @@ class WorkflowRunner(Generic[StateType]):
         if parent_context is None:
             parent_context = get_parent_context() or self._parent_context

-        with execution_context(parent_context=parent_context):
+        with execution_context(parent_context=parent_context, trace_id=node.state.meta.trace_id):
            self._run_work_item(node, span_id)

    def _handle_invoked_ports(self, state: StateType, ports: Optional[Iterable[Port]]) -> None:
@@ -524,7 +530,7 @@ class WorkflowRunner(Generic[StateType]):
        for node_cls in self._entrypoints:
            try:
                if not self._max_concurrency or len(self._active_nodes_by_execution_id) < self._max_concurrency:
-                    with execution_context(parent_context=current_parent):
+                    with execution_context(parent_context=current_parent, trace_id=self._initial_state.meta.trace_id):
                        self._run_node_if_ready(self._initial_state, node_cls)
                else:
                    self._concurrency_queue.put((self._initial_state, node_cls, None))
@@ -551,7 +557,7 @@ class WorkflowRunner(Generic[StateType]):

            self._workflow_event_outer_queue.put(event)

-            with execution_context(parent_context=current_parent):
+            with execution_context(parent_context=current_parent, trace_id=self._initial_state.meta.trace_id):
                rejection_error = self._handle_work_item_event(event)

            if rejection_error:
@@ -562,7 +568,7 @@ class WorkflowRunner(Generic[StateType]):
                while event := self._workflow_event_inner_queue.get_nowait():
                    self._workflow_event_outer_queue.put(event)

-                with execution_context(parent_context=current_parent):
+                with execution_context(parent_context=current_parent, trace_id=self._initial_state.meta.trace_id):
                    rejection_error = self._handle_work_item_event(event)

                    if rejection_error:
vellum/workflows/state/context.py
CHANGED
@@ -3,7 +3,7 @@ from queue import Queue
 from typing import TYPE_CHECKING, Dict, List, Optional, Type

 from vellum import Vellum
-from vellum.workflows.
+from vellum.workflows.context import ExecutionContext, get_execution_context
 from vellum.workflows.nodes.mocks import MockNodeExecution, MockNodeExecutionArg
 from vellum.workflows.outputs.base import BaseOutputs
 from vellum.workflows.references.constant import ConstantValueReference
@@ -18,12 +18,14 @@ class WorkflowContext:
         self,
         *,
         vellum_client: Optional[Vellum] = None,
-
+        execution_context: Optional[ExecutionContext] = None,
     ):
         self._vellum_client = vellum_client
-        self._parent_context = parent_context
         self._event_queue: Optional[Queue["WorkflowEvent"]] = None
         self._node_output_mocks_map: Dict[Type[BaseOutputs], List[MockNodeExecution]] = {}
+        self._execution_context = get_execution_context()
+        if not self._execution_context.parent_context and execution_context:
+            self._execution_context = execution_context

     @cached_property
     def vellum_client(self) -> Vellum:
@@ -33,10 +35,8 @@ class WorkflowContext:
         return create_vellum_client()

     @cached_property
-    def
-
-        return self._parent_context
-        return None
+    def execution_context(self) -> ExecutionContext:
+        return self._execution_context

     @cached_property
     def node_output_mocks_map(self) -> Dict[Type[BaseOutputs], List[MockNodeExecution]]:
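`WorkflowContext` now derives its execution context from the ambient one, only falling back to an explicitly passed `execution_context` when the ambient context carries no parent. A toy reproduction of that precedence; `SimpleExecutionContext` is a stand-in dataclass, not the SDK's `ExecutionContext`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SimpleExecutionContext:  # stand-in for vellum.workflows.context.ExecutionContext
    parent_context: Optional[object] = None
    trace_id: Optional[str] = None


def resolve_execution_context(
    ambient: SimpleExecutionContext,
    explicit: Optional[SimpleExecutionContext],
) -> SimpleExecutionContext:
    # Keep the ambient context unless it has no parent_context and an explicit
    # execution_context was supplied to the WorkflowContext constructor.
    if not ambient.parent_context and explicit:
        return explicit
    return ambient


print(resolve_execution_context(SimpleExecutionContext(), SimpleExecutionContext(trace_id="abc-123")))
```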
vellum/workflows/workflows/base.py
CHANGED
@@ -24,6 +24,7 @@ from typing import (
     get_args,
 )

+from vellum.workflows.context import get_execution_context
 from vellum.workflows.edges import Edge
 from vellum.workflows.emitters.base import BaseWorkflowEmitter
 from vellum.workflows.errors import WorkflowError, WorkflowErrorCode
@@ -171,6 +172,7 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
         self.resolvers = resolvers or (self.resolvers if hasattr(self, "resolvers") else [])
         self._context = context or WorkflowContext()
         self._store = Store()
+        self._execution_context = self._context.execution_context

         self.validate()

@@ -320,8 +322,8 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
             external_inputs=external_inputs,
             cancel_signal=cancel_signal,
             node_output_mocks=node_output_mocks,
-            parent_context=self._context.parent_context,
             max_concurrency=max_concurrency,
+            init_execution_context=self._execution_context,
         ).stream()
         first_event: Optional[Union[WorkflowExecutionInitiatedEvent, WorkflowExecutionResumedEvent]] = None
         last_event = None
@@ -431,8 +433,8 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
             external_inputs=external_inputs,
             cancel_signal=cancel_signal,
             node_output_mocks=node_output_mocks,
-            parent_context=self.context.parent_context,
             max_concurrency=max_concurrency,
+            init_execution_context=self._execution_context,
         ).stream():
             if should_yield(self.__class__, event):
                 yield event
@@ -479,10 +481,19 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
             return self.get_inputs_class()()

     def get_default_state(self, workflow_inputs: Optional[InputsType] = None) -> StateType:
+        execution_context = get_execution_context()
         return self.get_state_class()(
-            meta=
-
-
+            meta=(
+                StateMeta(
+                    parent=self._parent_state,
+                    workflow_inputs=workflow_inputs or self.get_default_inputs(),
+                    trace_id=execution_context.trace_id,
+                )
+                if execution_context and execution_context.trace_id
+                else StateMeta(
+                    parent=self._parent_state,
+                    workflow_inputs=workflow_inputs or self.get_default_inputs(),
+                )
            )
        )

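`BaseWorkflow.get_default_state` now seeds the initial state's trace_id from the ambient execution context when one is present. A simplified sketch of that branch; `make_state_meta` and its dict return value are illustrative stand-ins for constructing `StateMeta`.

```python
from typing import Any, Dict, Optional


def make_state_meta(parent: Any, workflow_inputs: Any, ambient_trace_id: Optional[str]) -> Dict[str, Any]:
    # When the ambient execution context carries a trace_id, thread it into the new
    # state's metadata; otherwise fall back to the previous default construction.
    meta: Dict[str, Any] = {"parent": parent, "workflow_inputs": workflow_inputs}
    if ambient_trace_id:
        meta["trace_id"] = ambient_trace_id
    return meta


print(make_state_meta(None, {"query": "hello"}, "abc-123"))
print(make_state_meta(None, {"query": "hello"}, None))
```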