vellum-ai 0.14.5__py3-none-any.whl → 0.14.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/__init__.py +18 -0
- vellum/client/__init__.py +8 -8
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/resources/__init__.py +2 -0
- vellum/client/resources/workflow_sandboxes/__init__.py +3 -0
- vellum/client/resources/workflow_sandboxes/client.py +146 -0
- vellum/client/resources/workflow_sandboxes/types/__init__.py +5 -0
- vellum/client/resources/workflow_sandboxes/types/list_workflow_sandbox_examples_request_tag.py +5 -0
- vellum/client/types/__init__.py +16 -0
- vellum/client/types/array_chat_message_content_item.py +6 -1
- vellum/client/types/array_chat_message_content_item_request.py +2 -0
- vellum/client/types/chat_message_content.py +2 -0
- vellum/client/types/chat_message_content_request.py +2 -0
- vellum/client/types/document_chat_message_content.py +25 -0
- vellum/client/types/document_chat_message_content_request.py +25 -0
- vellum/client/types/document_vellum_value.py +25 -0
- vellum/client/types/document_vellum_value_request.py +25 -0
- vellum/client/types/paginated_workflow_sandbox_example_list.py +23 -0
- vellum/client/types/vellum_document.py +20 -0
- vellum/client/types/vellum_document_request.py +20 -0
- vellum/client/types/vellum_value.py +2 -0
- vellum/client/types/vellum_value_request.py +2 -0
- vellum/client/types/vellum_variable_type.py +1 -0
- vellum/client/types/workflow_sandbox_example.py +22 -0
- vellum/resources/workflow_sandboxes/types/__init__.py +3 -0
- vellum/resources/workflow_sandboxes/types/list_workflow_sandbox_examples_request_tag.py +3 -0
- vellum/types/document_chat_message_content.py +3 -0
- vellum/types/document_chat_message_content_request.py +3 -0
- vellum/types/document_vellum_value.py +3 -0
- vellum/types/document_vellum_value_request.py +3 -0
- vellum/types/paginated_workflow_sandbox_example_list.py +3 -0
- vellum/types/vellum_document.py +3 -0
- vellum/types/vellum_document_request.py +3 -0
- vellum/types/workflow_sandbox_example.py +3 -0
- vellum/workflows/exceptions.py +18 -0
- vellum/workflows/inputs/base.py +27 -1
- vellum/workflows/inputs/tests/__init__.py +0 -0
- vellum/workflows/inputs/tests/test_inputs.py +49 -0
- vellum/workflows/nodes/core/inline_subworkflow_node/node.py +1 -1
- vellum/workflows/nodes/core/map_node/node.py +7 -7
- vellum/workflows/nodes/core/try_node/node.py +1 -1
- vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +2 -2
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +5 -3
- vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +5 -4
- vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +4 -4
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +49 -15
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/test_node.py +165 -0
- vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +3 -1
- vellum/workflows/outputs/base.py +1 -1
- vellum/workflows/runner/runner.py +16 -10
- vellum/workflows/state/context.py +7 -7
- vellum/workflows/workflows/base.py +61 -59
- vellum/workflows/workflows/tests/test_base_workflow.py +131 -40
- {vellum_ai-0.14.5.dist-info → vellum_ai-0.14.7.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.5.dist-info → vellum_ai-0.14.7.dist-info}/RECORD +68 -44
- vellum_cli/__init__.py +36 -0
- vellum_cli/init.py +128 -0
- vellum_cli/pull.py +6 -3
- vellum_cli/tests/test_init.py +355 -0
- vellum_cli/tests/test_pull.py +127 -0
- vellum_ee/workflows/display/nodes/base_node_display.py +4 -4
- vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +31 -0
- vellum_ee/workflows/display/nodes/vellum/utils.py +8 -0
- vellum_ee/workflows/display/vellum.py +0 -4
- vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +29 -0
- {vellum_ai-0.14.5.dist-info → vellum_ai-0.14.7.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.5.dist-info → vellum_ai-0.14.7.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.5.dist-info → vellum_ai-0.14.7.dist-info}/entry_points.txt +0 -0
| @@ -3,6 +3,7 @@ from datetime import datetime | |
| 3 3 | 
             
            from uuid import uuid4
         | 
| 4 4 | 
             
            from typing import Any, Iterator, List
         | 
| 5 5 |  | 
| 6 | 
            +
            from vellum.client.core.api_error import ApiError
         | 
| 6 7 | 
             
            from vellum.client.types.chat_message import ChatMessage
         | 
| 7 8 | 
             
            from vellum.client.types.chat_message_request import ChatMessageRequest
         | 
| 8 9 | 
             
            from vellum.client.types.workflow_execution_workflow_result_event import WorkflowExecutionWorkflowResultEvent
         | 
| @@ -11,6 +12,8 @@ from vellum.client.types.workflow_request_chat_history_input_request import Work | |
| 11 12 | 
             
            from vellum.client.types.workflow_request_json_input_request import WorkflowRequestJsonInputRequest
         | 
| 12 13 | 
             
            from vellum.client.types.workflow_result_event import WorkflowResultEvent
         | 
| 13 14 | 
             
            from vellum.client.types.workflow_stream_event import WorkflowStreamEvent
         | 
| 15 | 
            +
            from vellum.workflows.errors import WorkflowErrorCode
         | 
| 16 | 
            +
            from vellum.workflows.exceptions import NodeException
         | 
| 14 17 | 
             
            from vellum.workflows.nodes.displayable.subworkflow_deployment_node.node import SubworkflowDeploymentNode
         | 
| 15 18 |  | 
| 16 19 |  | 
| @@ -129,3 +132,165 @@ def test_run_workflow__any_array(vellum_client): | |
| 129 132 | 
             
                assert call_kwargs["inputs"] == [
         | 
| 130 133 | 
             
                    WorkflowRequestJsonInputRequest(name="fruits", value=["apple", "banana", "cherry"]),
         | 
| 131 134 | 
             
                ]
         | 
| 135 | 
            +
             | 
| 136 | 
            +
             | 
| 137 | 
            +
            def test_run_workflow__no_deployment():
         | 
| 138 | 
            +
                """Confirm that we raise an error when running a subworkflow deployment node with no deployment attribute set"""
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                # GIVEN a Subworkflow Deployment Node
         | 
| 141 | 
            +
                class ExampleSubworkflowDeploymentNode(SubworkflowDeploymentNode):
         | 
| 142 | 
            +
                    subworkflow_inputs = {
         | 
| 143 | 
            +
                        "fruits": ["apple", "banana", "cherry"],
         | 
| 144 | 
            +
                    }
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                # WHEN/THEN running the node should raise a NodeException
         | 
| 147 | 
            +
                node = ExampleSubworkflowDeploymentNode()
         | 
| 148 | 
            +
                with pytest.raises(NodeException) as exc_info:
         | 
| 149 | 
            +
                    list(node.run())
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                # AND the error message should be correct
         | 
| 152 | 
            +
                assert exc_info.value.code == WorkflowErrorCode.NODE_EXECUTION
         | 
| 153 | 
            +
                assert "Expected subworkflow deployment attribute to be either a UUID or STR, got None instead" in str(
         | 
| 154 | 
            +
                    exc_info.value
         | 
| 155 | 
            +
                )
         | 
| 156 | 
            +
             | 
| 157 | 
            +
             | 
| 158 | 
            +
            def test_run_workflow__hyphenated_output(vellum_client):
         | 
| 159 | 
            +
                """Confirm that we can successfully handle subworkflow outputs with hyphenated names"""
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                # GIVEN a Subworkflow Deployment Node
         | 
| 162 | 
            +
                class ExampleSubworkflowDeploymentNode(SubworkflowDeploymentNode):
         | 
| 163 | 
            +
                    deployment = "example_subworkflow_deployment"
         | 
| 164 | 
            +
                    subworkflow_inputs = {
         | 
| 165 | 
            +
                        "test_input": "test_value",
         | 
| 166 | 
            +
                    }
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    class Outputs(SubworkflowDeploymentNode.Outputs):
         | 
| 169 | 
            +
                        final_output_copy: str
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                # AND we know what the Subworkflow Deployment will respond with
         | 
| 172 | 
            +
                def generate_subworkflow_events(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
         | 
| 173 | 
            +
                    execution_id = str(uuid4())
         | 
| 174 | 
            +
                    expected_events: List[WorkflowStreamEvent] = [
         | 
| 175 | 
            +
                        WorkflowExecutionWorkflowResultEvent(
         | 
| 176 | 
            +
                            execution_id=execution_id,
         | 
| 177 | 
            +
                            data=WorkflowResultEvent(
         | 
| 178 | 
            +
                                id=str(uuid4()),
         | 
| 179 | 
            +
                                state="INITIATED",
         | 
| 180 | 
            +
                                ts=datetime.now(),
         | 
| 181 | 
            +
                            ),
         | 
| 182 | 
            +
                        ),
         | 
| 183 | 
            +
                        WorkflowExecutionWorkflowResultEvent(
         | 
| 184 | 
            +
                            execution_id=execution_id,
         | 
| 185 | 
            +
                            data=WorkflowResultEvent(
         | 
| 186 | 
            +
                                id=str(uuid4()),
         | 
| 187 | 
            +
                                state="FULFILLED",
         | 
| 188 | 
            +
                                ts=datetime.now(),
         | 
| 189 | 
            +
                                outputs=[
         | 
| 190 | 
            +
                                    WorkflowOutputString(
         | 
| 191 | 
            +
                                        id=str(uuid4()),
         | 
| 192 | 
            +
                                        name="final-output_copy",  # Note the hyphen here
         | 
| 193 | 
            +
                                        value="test success",
         | 
| 194 | 
            +
                                    )
         | 
| 195 | 
            +
                                ],
         | 
| 196 | 
            +
                            ),
         | 
| 197 | 
            +
                        ),
         | 
| 198 | 
            +
                    ]
         | 
| 199 | 
            +
                    yield from expected_events
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                vellum_client.execute_workflow_stream.side_effect = generate_subworkflow_events
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                # WHEN we run the node
         | 
| 204 | 
            +
                node = ExampleSubworkflowDeploymentNode()
         | 
| 205 | 
            +
                events = list(node.run())
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                # THEN the node should have completed successfully
         | 
| 208 | 
            +
                assert events[-1].name == "final_output_copy"  # Note the underscore here
         | 
| 209 | 
            +
                assert events[-1].value == "test success"
         | 
| 210 | 
            +
             | 
| 211 | 
            +
             | 
| 212 | 
            +
            @pytest.mark.parametrize(
         | 
| 213 | 
            +
                ["exception", "expected_code", "expected_message"],
         | 
| 214 | 
            +
                [
         | 
| 215 | 
            +
                    (
         | 
| 216 | 
            +
                        ApiError(status_code=400, body={"detail": "Missing required input variable: 'foo'"}),
         | 
| 217 | 
            +
                        WorkflowErrorCode.INVALID_INPUTS,
         | 
| 218 | 
            +
                        "Missing required input variable: 'foo'",
         | 
| 219 | 
            +
                    ),
         | 
| 220 | 
            +
                    (
         | 
| 221 | 
            +
                        ApiError(status_code=400, body={"message": "Missing required input variable: 'foo'"}),
         | 
| 222 | 
            +
                        WorkflowErrorCode.INVALID_INPUTS,
         | 
| 223 | 
            +
                        "Failed to execute Subworkflow Deployment",
         | 
| 224 | 
            +
                    ),
         | 
| 225 | 
            +
                    (
         | 
| 226 | 
            +
                        ApiError(status_code=400, body="Missing required input variable: 'foo'"),
         | 
| 227 | 
            +
                        WorkflowErrorCode.INTERNAL_ERROR,
         | 
| 228 | 
            +
                        "Failed to execute Subworkflow Deployment",
         | 
| 229 | 
            +
                    ),
         | 
| 230 | 
            +
                    (
         | 
| 231 | 
            +
                        ApiError(status_code=None, body={"detail": "Missing required input variable: 'foo'"}),
         | 
| 232 | 
            +
                        WorkflowErrorCode.INTERNAL_ERROR,
         | 
| 233 | 
            +
                        "Failed to execute Subworkflow Deployment",
         | 
| 234 | 
            +
                    ),
         | 
| 235 | 
            +
                    (
         | 
| 236 | 
            +
                        ApiError(status_code=500, body={"detail": "Missing required input variable: 'foo'"}),
         | 
| 237 | 
            +
                        WorkflowErrorCode.INTERNAL_ERROR,
         | 
| 238 | 
            +
                        "Failed to execute Subworkflow Deployment",
         | 
| 239 | 
            +
                    ),
         | 
| 240 | 
            +
                ],
         | 
| 241 | 
            +
                ids=["400", "invalid_dict", "invalid_body", "no_status_code", "500"],
         | 
| 242 | 
            +
            )
         | 
| 243 | 
            +
            def test_subworkflow_deployment_node__api_error__invalid_inputs_node_exception(
         | 
| 244 | 
            +
                vellum_client, exception, expected_code, expected_message
         | 
| 245 | 
            +
            ):
         | 
| 246 | 
            +
                # GIVEN a Subworkflow Deployment Node
         | 
| 247 | 
            +
                class MyNode(SubworkflowDeploymentNode):
         | 
| 248 | 
            +
                    deployment = "example_subworkflow_deployment"
         | 
| 249 | 
            +
                    subworkflow_inputs = {
         | 
| 250 | 
            +
                        "not_foo": "bar",
         | 
| 251 | 
            +
                    }
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                # AND the Subworkflow Deployment API call fails
         | 
| 254 | 
            +
                def _side_effect(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
         | 
| 255 | 
            +
                    if kwargs.get("_mock_condition_to_induce_an_error"):
         | 
| 256 | 
            +
                        yield WorkflowExecutionWorkflowResultEvent(
         | 
| 257 | 
            +
                            execution_id=str(uuid4()),
         | 
| 258 | 
            +
                            data=WorkflowResultEvent(
         | 
| 259 | 
            +
                                id=str(uuid4()),
         | 
| 260 | 
            +
                                state="INITIATED",
         | 
| 261 | 
            +
                                ts=datetime.now(),
         | 
| 262 | 
            +
                            ),
         | 
| 263 | 
            +
                        )
         | 
| 264 | 
            +
                    else:
         | 
| 265 | 
            +
                        raise exception
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                # AND the vellum client execute workflow stream raises the parametrized API error
         | 
| 268 | 
            +
                vellum_client.execute_workflow_stream.side_effect = _side_effect
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                # WHEN the node is run
         | 
| 271 | 
            +
                with pytest.raises(NodeException) as e:
         | 
| 272 | 
            +
                    list(MyNode().run())
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                # THEN the node raises the correct NodeException
         | 
| 275 | 
            +
                assert e.value.code == expected_code
         | 
| 276 | 
            +
                assert e.value.message == expected_message
         | 
| 277 | 
            +
             | 
| 278 | 
            +
             | 
| 279 | 
            +
            def test_subworkflow_deployment_node__immediate_api_error__node_exception(vellum_client):
         | 
| 280 | 
            +
                # GIVEN a Subworkflow Deployment Node
         | 
| 281 | 
            +
                class MyNode(SubworkflowDeploymentNode):
         | 
| 282 | 
            +
                    deployment = "example_subworkflow_deployment"
         | 
| 283 | 
            +
                    subworkflow_inputs = {
         | 
| 284 | 
            +
                        "not_foo": "bar",
         | 
| 285 | 
            +
                    }
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                # AND the vellum client execute workflow stream raises a 4xx error
         | 
| 288 | 
            +
                vellum_client.execute_workflow_stream.side_effect = ApiError(status_code=404, body={"detail": "Not found"})
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                # WHEN the node is run
         | 
| 291 | 
            +
                with pytest.raises(NodeException) as e:
         | 
| 292 | 
            +
                    list(MyNode().run())
         | 
| 293 | 
            +
             | 
| 294 | 
            +
                # THEN the node raises the correct NodeException
         | 
| 295 | 
            +
                assert e.value.code == WorkflowErrorCode.INVALID_INPUTS
         | 
| 296 | 
            +
                assert e.value.message == "Not found"
         | 
| @@ -74,5 +74,7 @@ def test_text_prompt_deployment_node__basic(vellum_client): | |
| 74 74 | 
             
                    prompt_deployment_name="my-deployment",
         | 
| 75 75 | 
             
                    raw_overrides=OMIT,
         | 
| 76 76 | 
             
                    release_tag="LATEST",
         | 
| 77 | 
            -
                    request_options={ | 
| 77 | 
            +
                    request_options={
         | 
| 78 | 
            +
                        "additional_body_parameters": {"execution_context": {"parent_context": None, "trace_id": None}}
         | 
| 79 | 
            +
                    },
         | 
| 78 80 | 
             
                )
         | 
    
        vellum/workflows/outputs/base.py
    CHANGED
    
    | @@ -32,7 +32,7 @@ class BaseOutput(Generic[_Delta, _Accumulated]): | |
| 32 32 | 
             
                    if value is not undefined and delta is not undefined:
         | 
| 33 33 | 
             
                        raise ValueError("Cannot set both value and delta")
         | 
| 34 34 |  | 
| 35 | 
            -
                    self._name = name
         | 
| 35 | 
            +
                    self._name = name.replace("-", "_")  # Convert hyphens to underscores for valid python variable names
         | 
| 36 36 | 
             
                    self._value = value
         | 
| 37 37 | 
             
                    self._delta = delta
         | 
| 38 38 |  | 
| @@ -7,7 +7,7 @@ from uuid import UUID | |
| 7 7 | 
             
            from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Tuple, Type, Union
         | 
| 8 8 |  | 
| 9 9 | 
             
            from vellum.workflows.constants import undefined
         | 
| 10 | 
            -
            from vellum.workflows.context import execution_context, get_parent_context
         | 
| 10 | 
            +
            from vellum.workflows.context import ExecutionContext, execution_context, get_execution_context, get_parent_context
         | 
| 11 11 | 
             
            from vellum.workflows.descriptors.base import BaseDescriptor
         | 
| 12 12 | 
             
            from vellum.workflows.edges.edge import Edge
         | 
| 13 13 | 
             
            from vellum.workflows.errors import WorkflowError, WorkflowErrorCode
         | 
| @@ -29,7 +29,7 @@ from vellum.workflows.events.node import ( | |
| 29 29 | 
             
                NodeExecutionRejectedBody,
         | 
| 30 30 | 
             
                NodeExecutionStreamingBody,
         | 
| 31 31 | 
             
            )
         | 
| 32 | 
            -
            from vellum.workflows.events.types import BaseEvent, NodeParentContext,  | 
| 32 | 
            +
            from vellum.workflows.events.types import BaseEvent, NodeParentContext, WorkflowParentContext
         | 
| 33 33 | 
             
            from vellum.workflows.events.workflow import (
         | 
| 34 34 | 
             
                WorkflowExecutionFulfilledBody,
         | 
| 35 35 | 
             
                WorkflowExecutionInitiatedBody,
         | 
| @@ -75,8 +75,8 @@ class WorkflowRunner(Generic[StateType]): | |
| 75 75 | 
             
                    external_inputs: Optional[ExternalInputsArg] = None,
         | 
| 76 76 | 
             
                    cancel_signal: Optional[ThreadingEvent] = None,
         | 
| 77 77 | 
             
                    node_output_mocks: Optional[MockNodeExecutionArg] = None,
         | 
| 78 | 
            -
                    parent_context: Optional[ParentContext] = None,
         | 
| 79 78 | 
             
                    max_concurrency: Optional[int] = None,
         | 
| 79 | 
            +
                    init_execution_context: Optional[ExecutionContext] = None,
         | 
| 80 80 | 
             
                ):
         | 
| 81 81 | 
             
                    if state and external_inputs:
         | 
| 82 82 | 
             
                        raise ValueError("Can only run a Workflow providing one of state or external inputs, not both")
         | 
| @@ -98,6 +98,11 @@ class WorkflowRunner(Generic[StateType]): | |
| 98 98 | 
             
                    elif external_inputs:
         | 
| 99 99 | 
             
                        self._initial_state = self.workflow.get_most_recent_state()
         | 
| 100 100 | 
             
                        for descriptor, value in external_inputs.items():
         | 
| 101 | 
            +
                            if not any(isinstance(value, type_) for type_ in descriptor.types):
         | 
| 102 | 
            +
                                raise NodeException(
         | 
| 103 | 
            +
                                    f"Invalid external input type for {descriptor.name}",
         | 
| 104 | 
            +
                                    code=WorkflowErrorCode.INVALID_INPUTS,
         | 
| 105 | 
            +
                                )
         | 
| 101 106 | 
             
                            self._initial_state.meta.external_inputs[descriptor] = value
         | 
| 102 107 |  | 
| 103 108 | 
             
                        self._entrypoints = [
         | 
| @@ -133,7 +138,8 @@ class WorkflowRunner(Generic[StateType]): | |
| 133 138 |  | 
| 134 139 | 
             
                    self._active_nodes_by_execution_id: Dict[UUID, BaseNode[StateType]] = {}
         | 
| 135 140 | 
             
                    self._cancel_signal = cancel_signal
         | 
| 136 | 
            -
                    self. | 
| 141 | 
            +
                    self._execution_context = init_execution_context or get_execution_context()
         | 
| 142 | 
            +
                    self._parent_context = self._execution_context.parent_context
         | 
| 137 143 |  | 
| 138 144 | 
             
                    setattr(
         | 
| 139 145 | 
             
                        self._initial_state,
         | 
| @@ -196,7 +202,7 @@ class WorkflowRunner(Generic[StateType]): | |
| 196 202 | 
             
                                break
         | 
| 197 203 |  | 
| 198 204 | 
             
                        if not was_mocked:
         | 
| 199 | 
            -
                            with execution_context(parent_context=updated_parent_context):
         | 
| 205 | 
            +
                            with execution_context(parent_context=updated_parent_context, trace_id=node.state.meta.trace_id):
         | 
| 200 206 | 
             
                                node_run_response = node.run()
         | 
| 201 207 |  | 
| 202 208 | 
             
                        ports = node.Ports()
         | 
| @@ -243,7 +249,7 @@ class WorkflowRunner(Generic[StateType]): | |
| 243 249 | 
             
                                    ),
         | 
| 244 250 | 
             
                                )
         | 
| 245 251 |  | 
| 246 | 
            -
                            with execution_context(parent_context=updated_parent_context):
         | 
| 252 | 
            +
                            with execution_context(parent_context=updated_parent_context, trace_id=node.state.meta.trace_id):
         | 
| 247 253 | 
             
                                for output in node_run_response:
         | 
| 248 254 | 
             
                                    invoked_ports = output > ports
         | 
| 249 255 | 
             
                                    if output.is_initiated:
         | 
| @@ -346,7 +352,7 @@ class WorkflowRunner(Generic[StateType]): | |
| 346 352 | 
             
                    if parent_context is None:
         | 
| 347 353 | 
             
                        parent_context = get_parent_context() or self._parent_context
         | 
| 348 354 |  | 
| 349 | 
            -
                    with execution_context(parent_context=parent_context):
         | 
| 355 | 
            +
                    with execution_context(parent_context=parent_context, trace_id=node.state.meta.trace_id):
         | 
| 350 356 | 
             
                        self._run_work_item(node, span_id)
         | 
| 351 357 |  | 
| 352 358 | 
             
                def _handle_invoked_ports(self, state: StateType, ports: Optional[Iterable[Port]]) -> None:
         | 
| @@ -524,7 +530,7 @@ class WorkflowRunner(Generic[StateType]): | |
| 524 530 | 
             
                    for node_cls in self._entrypoints:
         | 
| 525 531 | 
             
                        try:
         | 
| 526 532 | 
             
                            if not self._max_concurrency or len(self._active_nodes_by_execution_id) < self._max_concurrency:
         | 
| 527 | 
            -
                                with execution_context(parent_context=current_parent):
         | 
| 533 | 
            +
                                with execution_context(parent_context=current_parent, trace_id=self._initial_state.meta.trace_id):
         | 
| 528 534 | 
             
                                    self._run_node_if_ready(self._initial_state, node_cls)
         | 
| 529 535 | 
             
                            else:
         | 
| 530 536 | 
             
                                self._concurrency_queue.put((self._initial_state, node_cls, None))
         | 
| @@ -551,7 +557,7 @@ class WorkflowRunner(Generic[StateType]): | |
| 551 557 |  | 
| 552 558 | 
             
                        self._workflow_event_outer_queue.put(event)
         | 
| 553 559 |  | 
| 554 | 
            -
                        with execution_context(parent_context=current_parent):
         | 
| 560 | 
            +
                        with execution_context(parent_context=current_parent, trace_id=self._initial_state.meta.trace_id):
         | 
| 555 561 | 
             
                            rejection_error = self._handle_work_item_event(event)
         | 
| 556 562 |  | 
| 557 563 | 
             
                        if rejection_error:
         | 
| @@ -562,7 +568,7 @@ class WorkflowRunner(Generic[StateType]): | |
| 562 568 | 
             
                        while event := self._workflow_event_inner_queue.get_nowait():
         | 
| 563 569 | 
             
                            self._workflow_event_outer_queue.put(event)
         | 
| 564 570 |  | 
| 565 | 
            -
                            with execution_context(parent_context=current_parent):
         | 
| 571 | 
            +
                            with execution_context(parent_context=current_parent, trace_id=self._initial_state.meta.trace_id):
         | 
| 566 572 | 
             
                                rejection_error = self._handle_work_item_event(event)
         | 
| 567 573 |  | 
| 568 574 | 
             
                            if rejection_error:
         | 
| @@ -3,7 +3,7 @@ from queue import Queue | |
| 3 3 | 
             
            from typing import TYPE_CHECKING, Dict, List, Optional, Type
         | 
| 4 4 |  | 
| 5 5 | 
             
            from vellum import Vellum
         | 
| 6 | 
            -
            from vellum.workflows. | 
| 6 | 
            +
            from vellum.workflows.context import ExecutionContext, get_execution_context
         | 
| 7 7 | 
             
            from vellum.workflows.nodes.mocks import MockNodeExecution, MockNodeExecutionArg
         | 
| 8 8 | 
             
            from vellum.workflows.outputs.base import BaseOutputs
         | 
| 9 9 | 
             
            from vellum.workflows.references.constant import ConstantValueReference
         | 
| @@ -18,12 +18,14 @@ class WorkflowContext: | |
| 18 18 | 
             
                    self,
         | 
| 19 19 | 
             
                    *,
         | 
| 20 20 | 
             
                    vellum_client: Optional[Vellum] = None,
         | 
| 21 | 
            -
                     | 
| 21 | 
            +
                    execution_context: Optional[ExecutionContext] = None,
         | 
| 22 22 | 
             
                ):
         | 
| 23 23 | 
             
                    self._vellum_client = vellum_client
         | 
| 24 | 
            -
                    self._parent_context = parent_context
         | 
| 25 24 | 
             
                    self._event_queue: Optional[Queue["WorkflowEvent"]] = None
         | 
| 26 25 | 
             
                    self._node_output_mocks_map: Dict[Type[BaseOutputs], List[MockNodeExecution]] = {}
         | 
| 26 | 
            +
                    self._execution_context = get_execution_context()
         | 
| 27 | 
            +
                    if not self._execution_context.parent_context and execution_context:
         | 
| 28 | 
            +
                        self._execution_context = execution_context
         | 
| 27 29 |  | 
| 28 30 | 
             
                @cached_property
         | 
| 29 31 | 
             
                def vellum_client(self) -> Vellum:
         | 
| @@ -33,10 +35,8 @@ class WorkflowContext: | |
| 33 35 | 
             
                    return create_vellum_client()
         | 
| 34 36 |  | 
| 35 37 | 
             
                @cached_property
         | 
| 36 | 
            -
                def  | 
| 37 | 
            -
                     | 
| 38 | 
            -
                        return self._parent_context
         | 
| 39 | 
            -
                    return None
         | 
| 38 | 
            +
                def execution_context(self) -> ExecutionContext:
         | 
| 39 | 
            +
                    return self._execution_context
         | 
| 40 40 |  | 
| 41 41 | 
             
                @cached_property
         | 
| 42 42 | 
             
                def node_output_mocks_map(self) -> Dict[Type[BaseOutputs], List[MockNodeExecution]]:
         | 
| @@ -24,6 +24,7 @@ from typing import ( | |
| 24 24 | 
             
                get_args,
         | 
| 25 25 | 
             
            )
         | 
| 26 26 |  | 
| 27 | 
            +
            from vellum.workflows.context import get_execution_context
         | 
| 27 28 | 
             
            from vellum.workflows.edges import Edge
         | 
| 28 29 | 
             
            from vellum.workflows.emitters.base import BaseWorkflowEmitter
         | 
| 29 30 | 
             
            from vellum.workflows.errors import WorkflowError, WorkflowErrorCode
         | 
| @@ -171,6 +172,7 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta): | |
| 171 172 | 
             
                    self.resolvers = resolvers or (self.resolvers if hasattr(self, "resolvers") else [])
         | 
| 172 173 | 
             
                    self._context = context or WorkflowContext()
         | 
| 173 174 | 
             
                    self._store = Store()
         | 
| 175 | 
            +
                    self._execution_context = self._context.execution_context
         | 
| 174 176 |  | 
| 175 177 | 
             
                    self.validate()
         | 
| 176 178 |  | 
| @@ -178,48 +180,64 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta): | |
| 178 180 | 
             
                def context(self) -> WorkflowContext:
         | 
| 179 181 | 
             
                    return self._context
         | 
| 180 182 |  | 
| 181 | 
            -
                @ | 
| 182 | 
            -
                def  | 
| 183 | 
            -
                    original_graph = cls.graph
         | 
| 184 | 
            -
                    if isinstance(original_graph, Graph):
         | 
| 185 | 
            -
                        return [original_graph]
         | 
| 186 | 
            -
                    if isinstance(original_graph, set):
         | 
| 187 | 
            -
                        return [
         | 
| 188 | 
            -
                            subgraph if isinstance(subgraph, Graph) else Graph.from_node(subgraph) for subgraph in original_graph
         | 
| 189 | 
            -
                        ]
         | 
| 190 | 
            -
                    if issubclass(original_graph, BaseNode):
         | 
| 191 | 
            -
                        return [Graph.from_node(original_graph)]
         | 
| 192 | 
            -
             | 
| 193 | 
            -
                    raise ValueError(f"Unexpected graph type: {original_graph.__class__}")
         | 
| 194 | 
            -
             | 
| 195 | 
            -
                @classmethod
         | 
| 196 | 
            -
                def get_edges(cls) -> Iterator[Edge]:
         | 
| 183 | 
            +
                @staticmethod
         | 
| 184 | 
            +
                def _resolve_graph(graph: GraphAttribute) -> List[Graph]:
         | 
| 197 185 | 
             
                    """
         | 
| 198 | 
            -
                     | 
| 199 | 
            -
                    ensure uniqueness, and the iterator to preserve order.
         | 
| 186 | 
            +
                    Resolves a single graph source to a list of Graph objects.
         | 
| 200 187 | 
             
                    """
         | 
| 188 | 
            +
                    if isinstance(graph, Graph):
         | 
| 189 | 
            +
                        return [graph]
         | 
| 190 | 
            +
                    if isinstance(graph, set):
         | 
| 191 | 
            +
                        graphs = []
         | 
| 192 | 
            +
                        for item in graph:
         | 
| 193 | 
            +
                            if isinstance(item, Graph):
         | 
| 194 | 
            +
                                graphs.append(item)
         | 
| 195 | 
            +
                            elif issubclass(item, BaseNode):
         | 
| 196 | 
            +
                                graphs.append(Graph.from_node(item))
         | 
| 197 | 
            +
                            else:
         | 
| 198 | 
            +
                                raise ValueError(f"Unexpected graph type: {type(item)}")
         | 
| 199 | 
            +
                        return graphs
         | 
| 200 | 
            +
                    if issubclass(graph, BaseNode):
         | 
| 201 | 
            +
                        return [Graph.from_node(graph)]
         | 
| 202 | 
            +
                    raise ValueError(f"Unexpected graph type: {type(graph)}")
         | 
| 201 203 |  | 
| 204 | 
            +
                @staticmethod
         | 
| 205 | 
            +
                def _get_edges_from_subgraphs(subgraphs: Iterable[Graph]) -> Iterator[Edge]:
         | 
| 202 206 | 
             
                    edges = set()
         | 
| 203 | 
            -
                    subgraphs = cls.get_subgraphs()
         | 
| 204 207 | 
             
                    for subgraph in subgraphs:
         | 
| 205 208 | 
             
                        for edge in subgraph.edges:
         | 
| 206 209 | 
             
                            if edge not in edges:
         | 
| 207 210 | 
             
                                edges.add(edge)
         | 
| 208 211 | 
             
                                yield edge
         | 
| 209 212 |  | 
| 213 | 
            +
                @staticmethod
         | 
| 214 | 
            +
                def _get_nodes_from_subgraphs(subgraphs: Iterable[Graph]) -> Iterator[Type[BaseNode]]:
         | 
| 215 | 
            +
                    nodes = set()
         | 
| 216 | 
            +
                    for subgraph in subgraphs:
         | 
| 217 | 
            +
                        for node in subgraph.nodes:
         | 
| 218 | 
            +
                            if node not in nodes:
         | 
| 219 | 
            +
                                nodes.add(node)
         | 
| 220 | 
            +
                                yield node
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                @classmethod
         | 
| 223 | 
            +
                def get_subgraphs(cls) -> List[Graph]:
         | 
| 224 | 
            +
                    return cls._resolve_graph(cls.graph)
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                @classmethod
         | 
| 227 | 
            +
                def get_edges(cls) -> Iterator[Edge]:
         | 
| 228 | 
            +
                    """
         | 
| 229 | 
            +
                    Returns an iterator over the edges in the workflow. We use a set to
         | 
| 230 | 
            +
                    ensure uniqueness, and the iterator to preserve order.
         | 
| 231 | 
            +
                    """
         | 
| 232 | 
            +
                    return cls._get_edges_from_subgraphs(cls.get_subgraphs())
         | 
| 233 | 
            +
             | 
| 210 234 | 
             
                @classmethod
         | 
| 211 235 | 
             
                def get_nodes(cls) -> Iterator[Type[BaseNode]]:
         | 
| 212 236 | 
             
                    """
         | 
| 213 237 | 
             
                    Returns an iterator over the nodes in the workflow. We use a set to
         | 
| 214 238 | 
             
                    ensure uniqueness, and the iterator to preserve order.
         | 
| 215 239 | 
             
                    """
         | 
| 216 | 
            -
             | 
| 217 | 
            -
                    nodes = set()
         | 
| 218 | 
            -
                    for subgraph in cls.get_subgraphs():
         | 
| 219 | 
            -
                        for node in subgraph.nodes:
         | 
| 220 | 
            -
                            if node not in nodes:
         | 
| 221 | 
            -
                                nodes.add(node)
         | 
| 222 | 
            -
                                yield node
         | 
| 240 | 
            +
                    return cls._get_nodes_from_subgraphs(cls.get_subgraphs())
         | 
| 223 241 |  | 
| 224 242 | 
             
                @classmethod
         | 
| 225 243 | 
             
                def get_unused_subgraphs(cls) -> List[Graph]:
         | 
| @@ -228,19 +246,9 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta): | |
| 228 246 | 
             
                    """
         | 
| 229 247 | 
             
                    if not hasattr(cls, "unused_graphs"):
         | 
| 230 248 | 
             
                        return []
         | 
| 231 | 
            -
             | 
| 232 249 | 
             
                    graphs = []
         | 
| 233 250 | 
             
                    for item in cls.unused_graphs:
         | 
| 234 | 
            -
                         | 
| 235 | 
            -
                            graphs.append(item)
         | 
| 236 | 
            -
                        elif isinstance(item, set):
         | 
| 237 | 
            -
                            for subitem in item:
         | 
| 238 | 
            -
                                if isinstance(subitem, Graph):
         | 
| 239 | 
            -
                                    graphs.append(subitem)
         | 
| 240 | 
            -
                                elif issubclass(subitem, BaseNode):
         | 
| 241 | 
            -
                                    graphs.append(Graph.from_node(subitem))
         | 
| 242 | 
            -
                        elif issubclass(item, BaseNode):
         | 
| 243 | 
            -
                            graphs.append(Graph.from_node(item))
         | 
| 251 | 
            +
                        graphs.extend(cls._resolve_graph(item))
         | 
| 244 252 | 
             
                    return graphs
         | 
| 245 253 |  | 
| 246 254 | 
             
                @classmethod
         | 
| @@ -248,29 +256,14 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta): | |
| 248 256 | 
             
                    """
         | 
| 249 257 | 
             
                    Returns an iterator over the nodes that are defined but not used in the graph.
         | 
| 250 258 | 
             
                    """
         | 
| 251 | 
            -
                     | 
| 252 | 
            -
                        yield from ()
         | 
| 253 | 
            -
                    else:
         | 
| 254 | 
            -
                        nodes = set()
         | 
| 255 | 
            -
                        subgraphs = cls.get_unused_subgraphs()
         | 
| 256 | 
            -
                        for subgraph in subgraphs:
         | 
| 257 | 
            -
                            for node in subgraph.nodes:
         | 
| 258 | 
            -
                                if node not in nodes:
         | 
| 259 | 
            -
                                    nodes.add(node)
         | 
| 260 | 
            -
                                    yield node
         | 
| 259 | 
            +
                    return cls._get_nodes_from_subgraphs(cls.get_unused_subgraphs())
         | 
| 261 260 |  | 
| 262 261 | 
             
                @classmethod
         | 
| 263 262 | 
             
                def get_unused_edges(cls) -> Iterator[Edge]:
         | 
| 264 263 | 
             
                    """
         | 
| 265 264 | 
             
                    Returns an iterator over edges that are defined but not used in the graph.
         | 
| 266 265 | 
             
                    """
         | 
| 267 | 
            -
                     | 
| 268 | 
            -
                    subgraphs = cls.get_unused_subgraphs()
         | 
| 269 | 
            -
                    for subgraph in subgraphs:
         | 
| 270 | 
            -
                        for edge in subgraph.edges:
         | 
| 271 | 
            -
                            if edge not in edges:
         | 
| 272 | 
            -
                                edges.add(edge)
         | 
| 273 | 
            -
                                yield edge
         | 
| 266 | 
            +
                    return cls._get_edges_from_subgraphs(cls.get_unused_subgraphs())
         | 
| 274 267 |  | 
| 275 268 | 
             
                @classmethod
         | 
| 276 269 | 
             
                def get_entrypoints(cls) -> Iterable[Type[BaseNode]]:
         | 
| @@ -329,8 +322,8 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta): | |
| 329 322 | 
             
                        external_inputs=external_inputs,
         | 
| 330 323 | 
             
                        cancel_signal=cancel_signal,
         | 
| 331 324 | 
             
                        node_output_mocks=node_output_mocks,
         | 
| 332 | 
            -
                        parent_context=self._context.parent_context,
         | 
| 333 325 | 
             
                        max_concurrency=max_concurrency,
         | 
| 326 | 
            +
                        init_execution_context=self._execution_context,
         | 
| 334 327 | 
             
                    ).stream()
         | 
| 335 328 | 
             
                    first_event: Optional[Union[WorkflowExecutionInitiatedEvent, WorkflowExecutionResumedEvent]] = None
         | 
| 336 329 | 
             
                    last_event = None
         | 
| @@ -440,8 +433,8 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta): | |
| 440 433 | 
             
                        external_inputs=external_inputs,
         | 
| 441 434 | 
             
                        cancel_signal=cancel_signal,
         | 
| 442 435 | 
             
                        node_output_mocks=node_output_mocks,
         | 
| 443 | 
            -
                        parent_context=self.context.parent_context,
         | 
| 444 436 | 
             
                        max_concurrency=max_concurrency,
         | 
| 437 | 
            +
                        init_execution_context=self._execution_context,
         | 
| 445 438 | 
             
                    ).stream():
         | 
| 446 439 | 
             
                        if should_yield(self.__class__, event):
         | 
| 447 440 | 
             
                            yield event
         | 
| @@ -488,10 +481,19 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta): | |
| 488 481 | 
             
                    return self.get_inputs_class()()
         | 
| 489 482 |  | 
| 490 483 | 
             
                def get_default_state(self, workflow_inputs: Optional[InputsType] = None) -> StateType:
         | 
| 484 | 
            +
                    execution_context = get_execution_context()
         | 
| 491 485 | 
             
                    return self.get_state_class()(
         | 
| 492 | 
            -
                        meta= | 
| 493 | 
            -
                             | 
| 494 | 
            -
             | 
| 486 | 
            +
                        meta=(
         | 
| 487 | 
            +
                            StateMeta(
         | 
| 488 | 
            +
                                parent=self._parent_state,
         | 
| 489 | 
            +
                                workflow_inputs=workflow_inputs or self.get_default_inputs(),
         | 
| 490 | 
            +
                                trace_id=execution_context.trace_id,
         | 
| 491 | 
            +
                            )
         | 
| 492 | 
            +
                            if execution_context and execution_context.trace_id
         | 
| 493 | 
            +
                            else StateMeta(
         | 
| 494 | 
            +
                                parent=self._parent_state,
         | 
| 495 | 
            +
                                workflow_inputs=workflow_inputs or self.get_default_inputs(),
         | 
| 496 | 
            +
                            )
         | 
| 495 497 | 
             
                        )
         | 
| 496 498 | 
             
                    )
         | 
| 497 499 |  |