vellum-ai 0.14.6__py3-none-any.whl → 0.14.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. vellum/__init__.py +12 -0
  2. vellum/client/core/client_wrapper.py +1 -1
  3. vellum/client/types/__init__.py +12 -0
  4. vellum/client/types/array_chat_message_content_item.py +6 -1
  5. vellum/client/types/array_chat_message_content_item_request.py +2 -0
  6. vellum/client/types/chat_message_content.py +2 -0
  7. vellum/client/types/chat_message_content_request.py +2 -0
  8. vellum/client/types/document_chat_message_content.py +25 -0
  9. vellum/client/types/document_chat_message_content_request.py +25 -0
  10. vellum/client/types/document_vellum_value.py +25 -0
  11. vellum/client/types/document_vellum_value_request.py +25 -0
  12. vellum/client/types/vellum_document.py +20 -0
  13. vellum/client/types/vellum_document_request.py +20 -0
  14. vellum/client/types/vellum_value.py +2 -0
  15. vellum/client/types/vellum_value_request.py +2 -0
  16. vellum/client/types/vellum_variable_type.py +1 -0
  17. vellum/types/document_chat_message_content.py +3 -0
  18. vellum/types/document_chat_message_content_request.py +3 -0
  19. vellum/types/document_vellum_value.py +3 -0
  20. vellum/types/document_vellum_value_request.py +3 -0
  21. vellum/types/vellum_document.py +3 -0
  22. vellum/types/vellum_document_request.py +3 -0
  23. vellum/workflows/exceptions.py +18 -0
  24. vellum/workflows/inputs/base.py +27 -1
  25. vellum/workflows/inputs/tests/__init__.py +0 -0
  26. vellum/workflows/inputs/tests/test_inputs.py +49 -0
  27. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +1 -1
  28. vellum/workflows/nodes/core/map_node/node.py +7 -7
  29. vellum/workflows/nodes/core/try_node/node.py +1 -1
  30. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +2 -2
  31. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +5 -3
  32. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +5 -4
  33. vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +4 -4
  34. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +39 -15
  35. vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/test_node.py +142 -0
  36. vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +3 -1
  37. vellum/workflows/outputs/base.py +1 -1
  38. vellum/workflows/runner/runner.py +16 -10
  39. vellum/workflows/state/context.py +7 -7
  40. vellum/workflows/workflows/base.py +16 -5
  41. vellum/workflows/workflows/tests/test_base_workflow.py +131 -40
  42. {vellum_ai-0.14.6.dist-info → vellum_ai-0.14.7.dist-info}/METADATA +1 -1
  43. {vellum_ai-0.14.6.dist-info → vellum_ai-0.14.7.dist-info}/RECORD +54 -38
  44. vellum_cli/__init__.py +36 -0
  45. vellum_cli/init.py +128 -0
  46. vellum_cli/pull.py +6 -3
  47. vellum_cli/tests/test_init.py +355 -0
  48. vellum_cli/tests/test_pull.py +127 -0
  49. vellum_ee/workflows/display/nodes/base_node_display.py +4 -4
  50. vellum_ee/workflows/display/vellum.py +0 -4
  51. vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +29 -0
  52. {vellum_ai-0.14.6.dist-info → vellum_ai-0.14.7.dist-info}/LICENSE +0 -0
  53. {vellum_ai-0.14.6.dist-info → vellum_ai-0.14.7.dist-info}/WHEEL +0 -0
  54. {vellum_ai-0.14.6.dist-info → vellum_ai-0.14.7.dist-info}/entry_points.txt +0 -0
@@ -146,22 +146,22 @@ def test_inline_prompt_node__function_definitions(vellum_adhoc_prompt_client):
146
146
  (
147
147
  ApiError(status_code=404, body={"message": "Model not found"}),
148
148
  WorkflowErrorCode.INVALID_INPUTS,
149
- "Failed to execute prompt",
149
+ "Failed to execute Prompt",
150
150
  ),
151
151
  (
152
152
  ApiError(status_code=404, body="Model not found"),
153
153
  WorkflowErrorCode.INTERNAL_ERROR,
154
- "Failed to execute prompt",
154
+ "Failed to execute Prompt",
155
155
  ),
156
156
  (
157
157
  ApiError(status_code=None, body={"detail": "Model not found"}),
158
158
  WorkflowErrorCode.INTERNAL_ERROR,
159
- "Failed to execute prompt",
159
+ "Failed to execute Prompt",
160
160
  ),
161
161
  (
162
162
  ApiError(status_code=500, body={"detail": "Model not found"}),
163
163
  WorkflowErrorCode.INTERNAL_ERROR,
164
- "Failed to execute prompt",
164
+ "Failed to execute Prompt",
165
165
  ),
166
166
  ],
167
167
  ids=["404", "invalid_dict", "invalid_body", "no_status_code", "500"],
@@ -12,10 +12,11 @@ from vellum import (
12
12
  WorkflowRequestNumberInputRequest,
13
13
  WorkflowRequestStringInputRequest,
14
14
  )
15
+ from vellum.client.core.api_error import ApiError
15
16
  from vellum.client.types.chat_message_request import ChatMessageRequest
16
17
  from vellum.core import RequestOptions
17
18
  from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT
18
- from vellum.workflows.context import get_parent_context
19
+ from vellum.workflows.context import get_execution_context
19
20
  from vellum.workflows.errors import WorkflowErrorCode
20
21
  from vellum.workflows.errors.types import workflow_event_error_to_workflow_error
21
22
  from vellum.workflows.events.types import default_serializer
@@ -120,11 +121,13 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
120
121
  return compiled_inputs
121
122
 
122
123
  def run(self) -> Iterator[BaseOutput]:
123
- current_parent_context = get_parent_context()
124
- parent_context = current_parent_context.model_dump(mode="json") if current_parent_context else None
124
+ current_context = get_execution_context()
125
+ parent_context = (
126
+ current_context.parent_context.model_dump(mode="json") if current_context.parent_context else None
127
+ )
125
128
  request_options = self.request_options or RequestOptions()
126
129
  request_options["additional_body_parameters"] = {
127
- "execution_context": {"parent_context": parent_context},
130
+ "execution_context": {"parent_context": parent_context, "trace_id": current_context.trace_id},
128
131
  **request_options.get("additional_body_parameters", {}),
129
132
  }
130
133
 
@@ -137,17 +140,26 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
137
140
  message="Expected subworkflow deployment attribute to be either a UUID or STR, got None instead",
138
141
  )
139
142
 
140
- subworkflow_stream = self._context.vellum_client.execute_workflow_stream(
141
- inputs=self._compile_subworkflow_inputs(),
142
- workflow_deployment_id=deployment_id,
143
- workflow_deployment_name=deployment_name,
144
- release_tag=self.release_tag,
145
- external_id=self.external_id,
146
- event_types=["WORKFLOW"],
147
- metadata=self.metadata,
148
- request_options=request_options,
149
- )
150
- # for some reason execution context isn't showing as an option? ^ failing mypy
143
+ try:
144
+ subworkflow_stream = self._context.vellum_client.execute_workflow_stream(
145
+ inputs=self._compile_subworkflow_inputs(),
146
+ workflow_deployment_id=deployment_id,
147
+ workflow_deployment_name=deployment_name,
148
+ release_tag=self.release_tag,
149
+ external_id=self.external_id,
150
+ event_types=["WORKFLOW"],
151
+ metadata=self.metadata,
152
+ request_options=request_options,
153
+ )
154
+ except ApiError as e:
155
+ self._handle_api_error(e)
156
+
157
+ # We don't use the INITIATED event anyway, so we can just skip it
158
+ # and use the exception handling to catch other api level errors
159
+ try:
160
+ next(subworkflow_stream)
161
+ except ApiError as e:
162
+ self._handle_api_error(e)
151
163
 
152
164
  outputs: Optional[List[WorkflowOutput]] = None
153
165
  fulfilled_output_names: Set[str] = set()
@@ -195,3 +207,15 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
195
207
  name=output.name,
196
208
  value=output.value,
197
209
  )
210
+
211
+ def _handle_api_error(self, e: ApiError):  # Re-raise an ApiError from the Subworkflow Deployment API as a NodeException (never returns normally)
212
+ if e.status_code and e.status_code >= 400 and e.status_code < 500 and isinstance(e.body, dict):  # 4xx with a dict body: treat as invalid inputs, preferring the API's "detail" message
213
+ raise NodeException(
214
+ message=e.body.get("detail", "Failed to execute Subworkflow Deployment"),
215
+ code=WorkflowErrorCode.INVALID_INPUTS,
216
+ ) from e
217
+
218
+ raise NodeException(  # Everything else (no/zero status, 5xx, non-dict body) is reported as an internal error
219
+ message="Failed to execute Subworkflow Deployment",
220
+ code=WorkflowErrorCode.INTERNAL_ERROR,
221
+ ) from e
@@ -3,6 +3,7 @@ from datetime import datetime
3
3
  from uuid import uuid4
4
4
  from typing import Any, Iterator, List
5
5
 
6
+ from vellum.client.core.api_error import ApiError
6
7
  from vellum.client.types.chat_message import ChatMessage
7
8
  from vellum.client.types.chat_message_request import ChatMessageRequest
8
9
  from vellum.client.types.workflow_execution_workflow_result_event import WorkflowExecutionWorkflowResultEvent
@@ -152,3 +153,144 @@ def test_run_workflow__no_deployment():
152
153
  assert "Expected subworkflow deployment attribute to be either a UUID or STR, got None instead" in str(
153
154
  exc_info.value
154
155
  )
156
+
157
+
158
+ def test_run_workflow__hyphenated_output(vellum_client):
159
+ """Confirm that a subworkflow output whose name contains a hyphen is surfaced under its underscore-normalized name"""
160
+
161
+ # GIVEN a Subworkflow Deployment Node
162
+ class ExampleSubworkflowDeploymentNode(SubworkflowDeploymentNode):
163
+ deployment = "example_subworkflow_deployment"
164
+ subworkflow_inputs = {
165
+ "test_input": "test_value",
166
+ }
167
+
168
+ class Outputs(SubworkflowDeploymentNode.Outputs):
169
+ final_output_copy: str
170
+
171
+ # AND the Subworkflow Deployment streams an INITIATED event followed by a FULFILLED event with a hyphenated output name
172
+ def generate_subworkflow_events(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
173
+ execution_id = str(uuid4())
174
+ expected_events: List[WorkflowStreamEvent] = [
175
+ WorkflowExecutionWorkflowResultEvent(
176
+ execution_id=execution_id,
177
+ data=WorkflowResultEvent(
178
+ id=str(uuid4()),
179
+ state="INITIATED",
180
+ ts=datetime.now(),
181
+ ),
182
+ ),
183
+ WorkflowExecutionWorkflowResultEvent(
184
+ execution_id=execution_id,
185
+ data=WorkflowResultEvent(
186
+ id=str(uuid4()),
187
+ state="FULFILLED",
188
+ ts=datetime.now(),
189
+ outputs=[
190
+ WorkflowOutputString(
191
+ id=str(uuid4()),
192
+ name="final-output_copy", # Note the hyphen here
193
+ value="test success",
194
+ )
195
+ ],
196
+ ),
197
+ ),
198
+ ]
199
+ yield from expected_events
200
+
201
+ vellum_client.execute_workflow_stream.side_effect = generate_subworkflow_events
202
+
203
+ # WHEN we run the node
204
+ node = ExampleSubworkflowDeploymentNode()
205
+ events = list(node.run())
206
+
207
+ # THEN the last emitted output carries the underscore-normalized name and the original value
208
+ assert events[-1].name == "final_output_copy" # Note the underscore here
209
+ assert events[-1].value == "test success"
210
+
211
+
212
+ @pytest.mark.parametrize(
213
+ ["exception", "expected_code", "expected_message"],
214
+ [
215
+ (
216
+ ApiError(status_code=400, body={"detail": "Missing required input variable: 'foo'"}),
217
+ WorkflowErrorCode.INVALID_INPUTS,
218
+ "Missing required input variable: 'foo'",
219
+ ),
220
+ (
221
+ ApiError(status_code=400, body={"message": "Missing required input variable: 'foo'"}),
222
+ WorkflowErrorCode.INVALID_INPUTS,
223
+ "Failed to execute Subworkflow Deployment",
224
+ ),
225
+ (
226
+ ApiError(status_code=400, body="Missing required input variable: 'foo'"),
227
+ WorkflowErrorCode.INTERNAL_ERROR,
228
+ "Failed to execute Subworkflow Deployment",
229
+ ),
230
+ (
231
+ ApiError(status_code=None, body={"detail": "Missing required input variable: 'foo'"}),
232
+ WorkflowErrorCode.INTERNAL_ERROR,
233
+ "Failed to execute Subworkflow Deployment",
234
+ ),
235
+ (
236
+ ApiError(status_code=500, body={"detail": "Missing required input variable: 'foo'"}),
237
+ WorkflowErrorCode.INTERNAL_ERROR,
238
+ "Failed to execute Subworkflow Deployment",
239
+ ),
240
+ ],
241
+ ids=["400", "invalid_dict", "invalid_body", "no_status_code", "500"],
242
+ )
243
+ def test_subworkflow_deployment_node__api_error__invalid_inputs_node_exception(
244
+ vellum_client, exception, expected_code, expected_message
245
+ ):
246
+ # GIVEN a Subworkflow Deployment Node whose inputs omit the required 'foo' variable
247
+ class MyNode(SubworkflowDeploymentNode):
248
+ deployment = "example_subworkflow_deployment"
249
+ subworkflow_inputs = {
250
+ "not_foo": "bar",
251
+ }
252
+
253
+ # AND the Subworkflow Deployment API call fails once its stream is consumed (the yield only makes _side_effect a generator; the node never passes _mock_condition_to_induce_an_error)
254
+ def _side_effect(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
255
+ if kwargs.get("_mock_condition_to_induce_an_error"):
256
+ yield WorkflowExecutionWorkflowResultEvent(
257
+ execution_id=str(uuid4()),
258
+ data=WorkflowResultEvent(
259
+ id=str(uuid4()),
260
+ state="INITIATED",
261
+ ts=datetime.now(),
262
+ ),
263
+ )
264
+ else:
265
+ raise exception
266
+
267
+ # AND the vellum client execute workflow stream raises the parametrized ApiError (4xx, 5xx, or no status code)
268
+ vellum_client.execute_workflow_stream.side_effect = _side_effect
269
+
270
+ # WHEN the node is run
271
+ with pytest.raises(NodeException) as e:
272
+ list(MyNode().run())
273
+
274
+ # THEN the node raises a NodeException with the expected code and message for that ApiError
275
+ assert e.value.code == expected_code
276
+ assert e.value.message == expected_message
277
+
278
+
279
+ def test_subworkflow_deployment_node__immediate_api_error__node_exception(vellum_client):  # An ApiError raised by the initial call (not during streaming) is also translated
280
+ # GIVEN a Subworkflow Deployment Node targeting a deployment the API reports as not found
281
+ class MyNode(SubworkflowDeploymentNode):
282
+ deployment = "example_subworkflow_deployment"
283
+ subworkflow_inputs = {
284
+ "not_foo": "bar",
285
+ }
286
+
287
+ # AND the vellum client execute workflow stream call itself raises a 404 error, before any stream is returned
288
+ vellum_client.execute_workflow_stream.side_effect = ApiError(status_code=404, body={"detail": "Not found"})
289
+
290
+ # WHEN the node is run
291
+ with pytest.raises(NodeException) as e:
292
+ list(MyNode().run())
293
+
294
+ # THEN the 4xx "detail" is surfaced as an INVALID_INPUTS NodeException
295
+ assert e.value.code == WorkflowErrorCode.INVALID_INPUTS
296
+ assert e.value.message == "Not found"
@@ -74,5 +74,7 @@ def test_text_prompt_deployment_node__basic(vellum_client):
74
74
  prompt_deployment_name="my-deployment",
75
75
  raw_overrides=OMIT,
76
76
  release_tag="LATEST",
77
- request_options={"additional_body_parameters": {"execution_context": {"parent_context": None}}},
77
+ request_options={
78
+ "additional_body_parameters": {"execution_context": {"parent_context": None, "trace_id": None}}
79
+ },
78
80
  )
@@ -32,7 +32,7 @@ class BaseOutput(Generic[_Delta, _Accumulated]):
32
32
  if value is not undefined and delta is not undefined:
33
33
  raise ValueError("Cannot set both value and delta")
34
34
 
35
- self._name = name
35
+ self._name = name.replace("-", "_") # Convert hyphens to underscores for valid python variable names
36
36
  self._value = value
37
37
  self._delta = delta
38
38
 
@@ -7,7 +7,7 @@ from uuid import UUID
7
7
  from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Tuple, Type, Union
8
8
 
9
9
  from vellum.workflows.constants import undefined
10
- from vellum.workflows.context import execution_context, get_parent_context
10
+ from vellum.workflows.context import ExecutionContext, execution_context, get_execution_context, get_parent_context
11
11
  from vellum.workflows.descriptors.base import BaseDescriptor
12
12
  from vellum.workflows.edges.edge import Edge
13
13
  from vellum.workflows.errors import WorkflowError, WorkflowErrorCode
@@ -29,7 +29,7 @@ from vellum.workflows.events.node import (
29
29
  NodeExecutionRejectedBody,
30
30
  NodeExecutionStreamingBody,
31
31
  )
32
- from vellum.workflows.events.types import BaseEvent, NodeParentContext, ParentContext, WorkflowParentContext
32
+ from vellum.workflows.events.types import BaseEvent, NodeParentContext, WorkflowParentContext
33
33
  from vellum.workflows.events.workflow import (
34
34
  WorkflowExecutionFulfilledBody,
35
35
  WorkflowExecutionInitiatedBody,
@@ -75,8 +75,8 @@ class WorkflowRunner(Generic[StateType]):
75
75
  external_inputs: Optional[ExternalInputsArg] = None,
76
76
  cancel_signal: Optional[ThreadingEvent] = None,
77
77
  node_output_mocks: Optional[MockNodeExecutionArg] = None,
78
- parent_context: Optional[ParentContext] = None,
79
78
  max_concurrency: Optional[int] = None,
79
+ init_execution_context: Optional[ExecutionContext] = None,
80
80
  ):
81
81
  if state and external_inputs:
82
82
  raise ValueError("Can only run a Workflow providing one of state or external inputs, not both")
@@ -98,6 +98,11 @@ class WorkflowRunner(Generic[StateType]):
98
98
  elif external_inputs:
99
99
  self._initial_state = self.workflow.get_most_recent_state()
100
100
  for descriptor, value in external_inputs.items():
101
+ if not any(isinstance(value, type_) for type_ in descriptor.types):
102
+ raise NodeException(
103
+ f"Invalid external input type for {descriptor.name}",
104
+ code=WorkflowErrorCode.INVALID_INPUTS,
105
+ )
101
106
  self._initial_state.meta.external_inputs[descriptor] = value
102
107
 
103
108
  self._entrypoints = [
@@ -133,7 +138,8 @@ class WorkflowRunner(Generic[StateType]):
133
138
 
134
139
  self._active_nodes_by_execution_id: Dict[UUID, BaseNode[StateType]] = {}
135
140
  self._cancel_signal = cancel_signal
136
- self._parent_context = get_parent_context() or parent_context
141
+ self._execution_context = init_execution_context or get_execution_context()
142
+ self._parent_context = self._execution_context.parent_context
137
143
 
138
144
  setattr(
139
145
  self._initial_state,
@@ -196,7 +202,7 @@ class WorkflowRunner(Generic[StateType]):
196
202
  break
197
203
 
198
204
  if not was_mocked:
199
- with execution_context(parent_context=updated_parent_context):
205
+ with execution_context(parent_context=updated_parent_context, trace_id=node.state.meta.trace_id):
200
206
  node_run_response = node.run()
201
207
 
202
208
  ports = node.Ports()
@@ -243,7 +249,7 @@ class WorkflowRunner(Generic[StateType]):
243
249
  ),
244
250
  )
245
251
 
246
- with execution_context(parent_context=updated_parent_context):
252
+ with execution_context(parent_context=updated_parent_context, trace_id=node.state.meta.trace_id):
247
253
  for output in node_run_response:
248
254
  invoked_ports = output > ports
249
255
  if output.is_initiated:
@@ -346,7 +352,7 @@ class WorkflowRunner(Generic[StateType]):
346
352
  if parent_context is None:
347
353
  parent_context = get_parent_context() or self._parent_context
348
354
 
349
- with execution_context(parent_context=parent_context):
355
+ with execution_context(parent_context=parent_context, trace_id=node.state.meta.trace_id):
350
356
  self._run_work_item(node, span_id)
351
357
 
352
358
  def _handle_invoked_ports(self, state: StateType, ports: Optional[Iterable[Port]]) -> None:
@@ -524,7 +530,7 @@ class WorkflowRunner(Generic[StateType]):
524
530
  for node_cls in self._entrypoints:
525
531
  try:
526
532
  if not self._max_concurrency or len(self._active_nodes_by_execution_id) < self._max_concurrency:
527
- with execution_context(parent_context=current_parent):
533
+ with execution_context(parent_context=current_parent, trace_id=self._initial_state.meta.trace_id):
528
534
  self._run_node_if_ready(self._initial_state, node_cls)
529
535
  else:
530
536
  self._concurrency_queue.put((self._initial_state, node_cls, None))
@@ -551,7 +557,7 @@ class WorkflowRunner(Generic[StateType]):
551
557
 
552
558
  self._workflow_event_outer_queue.put(event)
553
559
 
554
- with execution_context(parent_context=current_parent):
560
+ with execution_context(parent_context=current_parent, trace_id=self._initial_state.meta.trace_id):
555
561
  rejection_error = self._handle_work_item_event(event)
556
562
 
557
563
  if rejection_error:
@@ -562,7 +568,7 @@ class WorkflowRunner(Generic[StateType]):
562
568
  while event := self._workflow_event_inner_queue.get_nowait():
563
569
  self._workflow_event_outer_queue.put(event)
564
570
 
565
- with execution_context(parent_context=current_parent):
571
+ with execution_context(parent_context=current_parent, trace_id=self._initial_state.meta.trace_id):
566
572
  rejection_error = self._handle_work_item_event(event)
567
573
 
568
574
  if rejection_error:
@@ -3,7 +3,7 @@ from queue import Queue
3
3
  from typing import TYPE_CHECKING, Dict, List, Optional, Type
4
4
 
5
5
  from vellum import Vellum
6
- from vellum.workflows.events.types import ParentContext
6
+ from vellum.workflows.context import ExecutionContext, get_execution_context
7
7
  from vellum.workflows.nodes.mocks import MockNodeExecution, MockNodeExecutionArg
8
8
  from vellum.workflows.outputs.base import BaseOutputs
9
9
  from vellum.workflows.references.constant import ConstantValueReference
@@ -18,12 +18,14 @@ class WorkflowContext:
18
18
  self,
19
19
  *,
20
20
  vellum_client: Optional[Vellum] = None,
21
- parent_context: Optional[ParentContext] = None,
21
+ execution_context: Optional[ExecutionContext] = None,
22
22
  ):
23
23
  self._vellum_client = vellum_client
24
- self._parent_context = parent_context
25
24
  self._event_queue: Optional[Queue["WorkflowEvent"]] = None
26
25
  self._node_output_mocks_map: Dict[Type[BaseOutputs], List[MockNodeExecution]] = {}
26
+ self._execution_context = get_execution_context()
27
+ if not self._execution_context.parent_context and execution_context:
28
+ self._execution_context = execution_context
27
29
 
28
30
  @cached_property
29
31
  def vellum_client(self) -> Vellum:
@@ -33,10 +35,8 @@ class WorkflowContext:
33
35
  return create_vellum_client()
34
36
 
35
37
  @cached_property
36
- def parent_context(self) -> Optional[ParentContext]:
37
- if self._parent_context:
38
- return self._parent_context
39
- return None
38
+ def execution_context(self) -> ExecutionContext:
39
+ return self._execution_context
40
40
 
41
41
  @cached_property
42
42
  def node_output_mocks_map(self) -> Dict[Type[BaseOutputs], List[MockNodeExecution]]:
@@ -24,6 +24,7 @@ from typing import (
24
24
  get_args,
25
25
  )
26
26
 
27
+ from vellum.workflows.context import get_execution_context
27
28
  from vellum.workflows.edges import Edge
28
29
  from vellum.workflows.emitters.base import BaseWorkflowEmitter
29
30
  from vellum.workflows.errors import WorkflowError, WorkflowErrorCode
@@ -171,6 +172,7 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
171
172
  self.resolvers = resolvers or (self.resolvers if hasattr(self, "resolvers") else [])
172
173
  self._context = context or WorkflowContext()
173
174
  self._store = Store()
175
+ self._execution_context = self._context.execution_context
174
176
 
175
177
  self.validate()
176
178
 
@@ -320,8 +322,8 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
320
322
  external_inputs=external_inputs,
321
323
  cancel_signal=cancel_signal,
322
324
  node_output_mocks=node_output_mocks,
323
- parent_context=self._context.parent_context,
324
325
  max_concurrency=max_concurrency,
326
+ init_execution_context=self._execution_context,
325
327
  ).stream()
326
328
  first_event: Optional[Union[WorkflowExecutionInitiatedEvent, WorkflowExecutionResumedEvent]] = None
327
329
  last_event = None
@@ -431,8 +433,8 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
431
433
  external_inputs=external_inputs,
432
434
  cancel_signal=cancel_signal,
433
435
  node_output_mocks=node_output_mocks,
434
- parent_context=self.context.parent_context,
435
436
  max_concurrency=max_concurrency,
437
+ init_execution_context=self._execution_context,
436
438
  ).stream():
437
439
  if should_yield(self.__class__, event):
438
440
  yield event
@@ -479,10 +481,19 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
479
481
  return self.get_inputs_class()()
480
482
 
481
483
  def get_default_state(self, workflow_inputs: Optional[InputsType] = None) -> StateType:
484
+ execution_context = get_execution_context()
482
485
  return self.get_state_class()(
483
- meta=StateMeta(
484
- parent=self._parent_state,
485
- workflow_inputs=workflow_inputs or self.get_default_inputs(),
486
+ meta=(
487
+ StateMeta(
488
+ parent=self._parent_state,
489
+ workflow_inputs=workflow_inputs or self.get_default_inputs(),
490
+ trace_id=execution_context.trace_id,
491
+ )
492
+ if execution_context and execution_context.trace_id
493
+ else StateMeta(
494
+ parent=self._parent_state,
495
+ workflow_inputs=workflow_inputs or self.get_default_inputs(),
496
+ )
486
497
  )
487
498
  )
488
499