vellum-ai 0.14.37__py3-none-any.whl → 0.14.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. vellum/__init__.py +8 -0
  2. vellum/client/core/client_wrapper.py +1 -1
  3. vellum/client/reference.md +6272 -0
  4. vellum/client/types/__init__.py +8 -0
  5. vellum/client/types/ad_hoc_fulfilled_prompt_execution_meta.py +2 -0
  6. vellum/client/types/fulfilled_prompt_execution_meta.py +2 -0
  7. vellum/client/types/test_suite_run_exec_config_request.py +4 -0
  8. vellum/client/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +27 -0
  9. vellum/client/types/test_suite_run_prompt_sandbox_exec_config_request.py +29 -0
  10. vellum/client/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +22 -0
  11. vellum/client/types/test_suite_run_workflow_sandbox_exec_config_request.py +29 -0
  12. vellum/plugins/pydantic.py +1 -1
  13. vellum/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +3 -0
  14. vellum/types/test_suite_run_prompt_sandbox_exec_config_request.py +3 -0
  15. vellum/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +3 -0
  16. vellum/types/test_suite_run_workflow_sandbox_exec_config_request.py +3 -0
  17. vellum/workflows/events/node.py +2 -1
  18. vellum/workflows/events/types.py +3 -40
  19. vellum/workflows/events/workflow.py +2 -1
  20. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +94 -3
  21. vellum/workflows/nodes/displayable/conftest.py +2 -6
  22. vellum/workflows/nodes/displayable/guardrail_node/node.py +1 -1
  23. vellum/workflows/nodes/displayable/guardrail_node/tests/__init__.py +0 -0
  24. vellum/workflows/nodes/displayable/guardrail_node/tests/test_node.py +50 -0
  25. vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +297 -0
  26. vellum/workflows/runner/runner.py +44 -43
  27. vellum/workflows/state/base.py +149 -45
  28. vellum/workflows/types/definition.py +71 -0
  29. vellum/workflows/types/generics.py +34 -1
  30. vellum/workflows/workflows/base.py +20 -3
  31. vellum/workflows/workflows/tests/test_base_workflow.py +232 -1
  32. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/METADATA +1 -1
  33. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/RECORD +37 -25
  34. vellum_ee/workflows/display/vellum.py +0 -5
  35. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/LICENSE +0 -0
  36. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/WHEEL +0 -0
  37. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/entry_points.txt +0 -0
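The most visible change in 0.14.38 is fallback-model support on PromptDeploymentNode, exercised by the new tests in vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py reproduced below. As a rough orientation only, here is a minimal sketch of how the new ml_model_fallbacks attribute appears to be used, inferred from those tests; the deployment name and inputs are placeholders, not values taken from this diff.

from vellum.workflows.nodes.displayable.prompt_deployment_node.node import PromptDeploymentNode


class ExamplePromptNode(PromptDeploymentNode):
    # Placeholder deployment name and inputs, for illustration only.
    deployment = "my_deployment"
    prompt_inputs = {"query": "example query"}
    # Per the tests below, when the primary model fails (a 404 ApiError or a
    # PROVIDER_ERROR event), the node retries each fallback model in order and
    # raises NodeException only after every fallback has also failed.
    ml_model_fallbacks = ["gpt-4o", "gemini-1.5-flash-latest"]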
vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py
@@ -5,6 +5,8 @@ from typing import Any, Iterator, List
 
 from httpx import Response
 
+from vellum import RejectedExecutePromptEvent
+from vellum.client import ApiError
 from vellum.client.types.chat_history_input_request import ChatHistoryInputRequest
 from vellum.client.types.chat_message import ChatMessage
 from vellum.client.types.chat_message_request import ChatMessageRequest
@@ -15,6 +17,8 @@ from vellum.client.types.json_input_request import JsonInputRequest
 from vellum.client.types.prompt_output import PromptOutput
 from vellum.client.types.string_vellum_value import StringVellumValue
 from vellum.workflows.context import execution_context
+from vellum.workflows.errors import WorkflowErrorCode
+from vellum.workflows.exceptions import NodeException
 from vellum.workflows.nodes.displayable.prompt_deployment_node.node import PromptDeploymentNode
 
 
@@ -194,3 +198,296 @@ def test_prompt_deployment_node__json_output(vellum_client):
     json_output = outputs[2]
     assert json_output.name == "json"
     assert json_output.value == expected_json
+
+
+def test_prompt_deployment_node__all_fallbacks_fail(vellum_client):
+    # GIVEN a Prompt Deployment Node with fallback models
+    class TestPromptDeploymentNode(PromptDeploymentNode):
+        deployment = "test_deployment"
+        prompt_inputs = {"query": "test query"}
+        ml_model_fallbacks = ["fallback_model_1", "fallback_model_2"]
+
+    # AND all models fail with 404 errors
+    primary_error = ApiError(
+        body={"detail": "Failed to find model 'primary_model'"},
+        status_code=404,
+    )
+    fallback1_error = ApiError(
+        body={"detail": "Failed to find model 'fallback_model_1'"},
+        status_code=404,
+    )
+    fallback2_error = ApiError(
+        body={"detail": "Failed to find model 'fallback_model_2'"},
+        status_code=404,
+    )
+
+    vellum_client.execute_prompt_stream.side_effect = [primary_error, fallback1_error, fallback2_error]
+
+    # WHEN we run the node
+    node = TestPromptDeploymentNode()
+
+    # THEN an exception should be raised
+    with pytest.raises(NodeException) as exc_info:
+        list(node.run())
+
+    # AND the client should have been called three times
+    assert vellum_client.execute_prompt_stream.call_count == 3
+
+    # AND we get the expected error message
+    assert (
+        exc_info.value.message
+        == "Failed to execute prompts with these fallbacks: ['fallback_model_1', 'fallback_model_2']"
+    )
+
+
+def test_prompt_deployment_node__fallback_success(vellum_client):
+    # GIVEN a Prompt Deployment Node with fallback models
+    class TestPromptDeploymentNode(PromptDeploymentNode):
+        deployment = "test_deployment"
+        prompt_inputs = {"query": "test query"}
+        ml_model_fallbacks = ["fallback_model_1", "fallback_model_2"]
+
+    # AND the primary model fails with a 404 error
+    primary_error = ApiError(
+        body={"detail": "Failed to find model 'primary_model'"},
+        status_code=404,
+    )
+
+    # AND the first fallback model succeeds
+    def generate_successful_stream():
+        execution_id = str(uuid4())
+        events = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            FulfilledExecutePromptEvent(
+                execution_id=execution_id, outputs=[StringVellumValue(value="Fallback response")]
+            ),
+        ]
+        return iter(events)
+
+    # Set up the mock to fail on primary but succeed on first fallback
+    vellum_client.execute_prompt_stream.side_effect = [primary_error, generate_successful_stream()]
+
+    # WHEN we run the node
+    node = TestPromptDeploymentNode()
+    outputs = list(node.run())
+
+    # THEN the node should complete successfully using the fallback model
+    assert len(outputs) > 0
+    assert outputs[-1].value == "Fallback response"
+
+    # AND the client should have been called twice (once for primary, once for fallback)
+    assert vellum_client.execute_prompt_stream.call_count == 2
+
+    # AND the second call should include the fallback model override
+    second_call_kwargs = vellum_client.execute_prompt_stream.call_args_list[1][1]
+    body_params = second_call_kwargs["request_options"]["additional_body_parameters"]
+    assert body_params["overrides"]["ml_model_fallback"] == "fallback_model_1"
+
+
+def test_prompt_deployment_node__provider_error_with_fallbacks(vellum_client):
+    # GIVEN a Prompt Deployment Node with fallback models
+    class TestPromptDeploymentNode(PromptDeploymentNode):
+        deployment = "test_deployment"
+        prompt_inputs = {}
+        ml_model_fallbacks = ["gpt-4o", "gemini-1.5-flash-latest"]
+
+    # AND the primary model starts but then fails with a provider error
+    def generate_primary_events():
+        execution_id = str(uuid4())
+        events = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            RejectedExecutePromptEvent(
+                execution_id=execution_id,
+                error={
+                    "code": "PROVIDER_ERROR",
+                    "message": "The model provider encountered an error",
+                },
+            ),
+        ]
+        return iter(events)
+
+    # AND the fallback model succeeds
+    def generate_fallback_events():
+        execution_id = str(uuid4())
+        expected_outputs: List[PromptOutput] = [StringVellumValue(value="Fallback response")]
+        events = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            FulfilledExecutePromptEvent(execution_id=execution_id, outputs=expected_outputs),
+        ]
+        return iter(events)
+
+    vellum_client.execute_prompt_stream.side_effect = [generate_primary_events(), generate_fallback_events()]
+
+    # WHEN we run the node
+    node = TestPromptDeploymentNode()
+    outputs = list(node.run())
+
+    # THEN the node should complete successfully using the fallback model
+    assert len(outputs) > 0
+    assert outputs[-1].value == "Fallback response"
+
+    # AND the client should have been called twice
+    assert vellum_client.execute_prompt_stream.call_count == 2
+
+    # AND the second call should include the fallback model override
+    second_call_kwargs = vellum_client.execute_prompt_stream.call_args_list[1][1]
+    body_params = second_call_kwargs["request_options"]["additional_body_parameters"]
+    assert body_params["overrides"]["ml_model_fallback"] == "gpt-4o"
+
+
+def test_prompt_deployment_node__multiple_fallbacks_mixed_errors(vellum_client):
+    """
+    This test case is when the primary model fails with an api error and
+    the first fallback fails with a provider error
+    """
+
+    # GIVEN a Prompt Deployment Node with multiple fallback models
+    class TestPromptDeploymentNode(PromptDeploymentNode):
+        deployment = "test_deployment"
+        prompt_inputs = {}
+        ml_model_fallbacks = ["gpt-4o", "gemini-1.5-flash-latest"]
+
+    # AND the primary model fails with an API error
+    primary_error = ApiError(
+        body={"detail": "Failed to find model 'primary_model'"},
+        status_code=404,
+    )
+
+    # AND the first fallback model fails with a provider error
+    def generate_fallback1_events():
+        execution_id = str(uuid4())
+        events = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            RejectedExecutePromptEvent(
+                execution_id=execution_id,
+                error={
+                    "code": "PROVIDER_ERROR",
+                    "message": "The first fallback provider encountered an error",
+                },
+            ),
+        ]
+        return iter(events)
+
+    # AND the second fallback model succeeds
+    def generate_fallback2_events():
+        execution_id = str(uuid4())
+        expected_outputs: List[PromptOutput] = [StringVellumValue(value="Second fallback response")]
+        events = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            FulfilledExecutePromptEvent(execution_id=execution_id, outputs=expected_outputs),
+        ]
+        return iter(events)
+
+    vellum_client.execute_prompt_stream.side_effect = [
+        primary_error,
+        generate_fallback1_events(),
+        generate_fallback2_events(),
+    ]
+
+    # WHEN we run the node
+    node = TestPromptDeploymentNode()
+    outputs = list(node.run())
+
+    # THEN the node should complete successfully using the second fallback model
+    assert len(outputs) > 0
+    assert outputs[-1].value == "Second fallback response"
+
+    # AND the client should have been called three times
+    assert vellum_client.execute_prompt_stream.call_count == 3
+
+    # AND the calls should include the correct model overrides
+    first_fallback_call = vellum_client.execute_prompt_stream.call_args_list[1][1]
+    first_fallback_params = first_fallback_call["request_options"]["additional_body_parameters"]
+    assert first_fallback_params["overrides"]["ml_model_fallback"] == "gpt-4o"
+
+    second_fallback_call = vellum_client.execute_prompt_stream.call_args_list[2][1]
+    second_fallback_params = second_fallback_call["request_options"]["additional_body_parameters"]
+    assert second_fallback_params["overrides"]["ml_model_fallback"] == "gemini-1.5-flash-latest"
+
+
+def test_prompt_deployment_node_multiple_provider_errors(vellum_client):
+    # GIVEN a Prompt Deployment Node with a single fallback model
+    class TestPromptDeploymentNode(PromptDeploymentNode):
+        deployment = "test_deployment"
+        prompt_inputs = {}
+        ml_model_fallbacks = ["gpt-4o"]
+
+    # AND the primary model fails with a provider error
+    def generate_primary_events():
+        execution_id = str(uuid4())
+        events = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            RejectedExecutePromptEvent(
+                execution_id=execution_id,
+                error={
+                    "code": "PROVIDER_ERROR",
+                    "message": "The primary provider encountered an error",
+                },
+            ),
+        ]
+        return iter(events)
+
+    # AND the fallback model also fails with a provider error
+    def generate_fallback1_events():
+        execution_id = str(uuid4())
+        events = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            RejectedExecutePromptEvent(
+                execution_id=execution_id,
+                error={
+                    "code": "PROVIDER_ERROR",
+                    "message": "The first fallback provider encountered an error",
+                },
+            ),
+        ]
+        return iter(events)
+
+    vellum_client.execute_prompt_stream.side_effect = [
+        generate_primary_events(),
+        generate_fallback1_events(),
+    ]
+
+    # WHEN we run the node
+    with pytest.raises(NodeException) as exc_info:
+        node = TestPromptDeploymentNode()
+        list(node.run())
+
+    # THEN we should get an exception
+    assert exc_info.value.message == "Failed to execute prompts with these fallbacks: ['gpt-4o']"
+
+    # AND the client should have been called two times
+    assert vellum_client.execute_prompt_stream.call_count == 2
+
+    # AND the calls should include the correct model overrides
+    first_fallback_call = vellum_client.execute_prompt_stream.call_args_list[1][1]
+    first_fallback_params = first_fallback_call["request_options"]["additional_body_parameters"]
+    assert first_fallback_params["overrides"]["ml_model_fallback"] == "gpt-4o"
+
+
+def test_prompt_deployment_node__no_fallbacks(vellum_client):
+    # GIVEN a Prompt Deployment Node with no fallback models
+    class TestPromptDeploymentNode(PromptDeploymentNode):
+        deployment = "test_deployment"
+        prompt_inputs = {}
+
+    # AND the primary model fails with an API error
+    primary_error = ApiError(
+        body={"detail": "Failed to find model 'primary_model'"},
+        status_code=404,
+    )
+
+    vellum_client.execute_prompt_stream.side_effect = primary_error
+
+    # WHEN we run the node
+    node = TestPromptDeploymentNode()
+
+    # THEN the node should raise an exception
+    with pytest.raises(NodeException) as exc_info:
+        list(node.run())
+
+    # AND the exception should contain the original error message
+    assert exc_info.value.message == "Failed to find model 'primary_model'"
+    assert exc_info.value.code == WorkflowErrorCode.INVALID_INPUTS
+
+    # AND the client should have been called only once (for the primary model)
+    assert vellum_client.execute_prompt_stream.call_count == 1
vellum/workflows/runner/runner.py
@@ -146,7 +146,6 @@ class WorkflowRunner(Generic[StateType]):
         self._active_nodes_by_execution_id: Dict[UUID, ActiveNode[StateType]] = {}
         self._cancel_signal = cancel_signal
         self._execution_context = init_execution_context or get_execution_context()
-        self._parent_context = self._execution_context.parent_context
 
         setattr(
             self._initial_state,
@@ -159,13 +158,13 @@ class WorkflowRunner(Generic[StateType]):
     def _snapshot_state(self, state: StateType) -> StateType:
         self._workflow_event_inner_queue.put(
             WorkflowExecutionSnapshottedEvent(
-                trace_id=state.meta.trace_id,
+                trace_id=self._execution_context.trace_id,
                 span_id=state.meta.span_id,
                 body=WorkflowExecutionSnapshottedBody(
                     workflow_definition=self.workflow.__class__,
                     state=state,
                 ),
-                parent=self._parent_context,
+                parent=self._execution_context.parent_context,
             )
         )
         self.workflow._store.append_state_snapshot(state)
@@ -178,16 +177,16 @@ class WorkflowRunner(Generic[StateType]):
         return event
 
     def _run_work_item(self, node: BaseNode[StateType], span_id: UUID) -> None:
-        parent_context = get_parent_context()
+        execution = get_execution_context()
         self._workflow_event_inner_queue.put(
             NodeExecutionInitiatedEvent(
-                trace_id=node.state.meta.trace_id,
+                trace_id=execution.trace_id,
                 span_id=span_id,
                 body=NodeExecutionInitiatedBody(
                     node_definition=node.__class__,
                     inputs=node._inputs,
                 ),
-                parent=parent_context,
+                parent=execution.parent_context,
             )
         )
 
@@ -197,7 +196,7 @@ class WorkflowRunner(Generic[StateType]):
             updated_parent_context = NodeParentContext(
                 span_id=span_id,
                 node_definition=node.__class__,
-                parent=parent_context,
+                parent=execution.parent_context,
             )
             node_run_response: NodeRunResponse
             was_mocked: Optional[bool] = None
@@ -209,7 +208,7 @@ class WorkflowRunner(Generic[StateType]):
                     break
 
             if not was_mocked:
-                with execution_context(parent_context=updated_parent_context, trace_id=node.state.meta.trace_id):
+                with execution_context(parent_context=updated_parent_context, trace_id=execution.trace_id):
                     node_run_response = node.run()
 
             ports = node.Ports()
@@ -232,7 +231,7 @@ class WorkflowRunner(Generic[StateType]):
             outputs = node.Outputs()
 
             def initiate_node_streaming_output(output: BaseOutput) -> None:
-                parent_context = get_parent_context()
+                execution = get_execution_context()
                 streaming_output_queues[output.name] = Queue()
                 output_descriptor = OutputReference(
                     name=output.name,
@@ -245,18 +244,18 @@ class WorkflowRunner(Generic[StateType]):
                 initiated_ports = initiated_output > ports
                 self._workflow_event_inner_queue.put(
                     NodeExecutionStreamingEvent(
-                        trace_id=node.state.meta.trace_id,
+                        trace_id=execution.trace_id,
                         span_id=span_id,
                         body=NodeExecutionStreamingBody(
                             node_definition=node.__class__,
                             output=initiated_output,
                             invoked_ports=initiated_ports,
                         ),
-                        parent=parent_context,
+                        parent=execution.parent_context,
                     ),
                 )
 
-            with execution_context(parent_context=updated_parent_context, trace_id=node.state.meta.trace_id):
+            with execution_context(parent_context=updated_parent_context, trace_id=execution.trace_id):
                 for output in node_run_response:
                     invoked_ports = output > ports
                     if output.is_initiated:
@@ -268,14 +267,14 @@ class WorkflowRunner(Generic[StateType]):
                         streaming_output_queues[output.name].put(output.delta)
                         self._workflow_event_inner_queue.put(
                             NodeExecutionStreamingEvent(
-                                trace_id=node.state.meta.trace_id,
+                                trace_id=execution.trace_id,
                                 span_id=span_id,
                                 body=NodeExecutionStreamingBody(
                                     node_definition=node.__class__,
                                     output=output,
                                     invoked_ports=invoked_ports,
                                 ),
-                                parent=parent_context,
+                                parent=execution.parent_context,
                             ),
                         )
                     elif output.is_fulfilled:
@@ -285,14 +284,14 @@ class WorkflowRunner(Generic[StateType]):
                         setattr(outputs, output.name, output.value)
                         self._workflow_event_inner_queue.put(
                             NodeExecutionStreamingEvent(
-                                trace_id=node.state.meta.trace_id,
+                                trace_id=execution.trace_id,
                                 span_id=span_id,
                                 body=NodeExecutionStreamingBody(
                                     node_definition=node.__class__,
                                     output=output,
                                     invoked_ports=invoked_ports,
                                 ),
-                                parent=parent_context,
+                                parent=execution.parent_context,
                             )
                         )
 
@@ -309,7 +308,7 @@ class WorkflowRunner(Generic[StateType]):
             invoked_ports = ports(outputs, node.state)
             self._workflow_event_inner_queue.put(
                 NodeExecutionFulfilledEvent(
-                    trace_id=node.state.meta.trace_id,
+                    trace_id=execution.trace_id,
                     span_id=span_id,
                     body=NodeExecutionFulfilledBody(
                         node_definition=node.__class__,
@@ -317,33 +316,33 @@ class WorkflowRunner(Generic[StateType]):
                         invoked_ports=invoked_ports,
                         mocked=was_mocked,
                     ),
-                    parent=parent_context,
+                    parent=execution.parent_context,
                 )
             )
         except NodeException as e:
             logger.info(e)
             self._workflow_event_inner_queue.put(
                 NodeExecutionRejectedEvent(
-                    trace_id=node.state.meta.trace_id,
+                    trace_id=execution.trace_id,
                     span_id=span_id,
                     body=NodeExecutionRejectedBody(
                         node_definition=node.__class__,
                         error=e.error,
                     ),
-                    parent=parent_context,
+                    parent=execution.parent_context,
                 )
             )
         except WorkflowInitializationException as e:
             logger.info(e)
             self._workflow_event_inner_queue.put(
                 NodeExecutionRejectedEvent(
-                    trace_id=node.state.meta.trace_id,
+                    trace_id=execution.trace_id,
                     span_id=span_id,
                     body=NodeExecutionRejectedBody(
                         node_definition=node.__class__,
                         error=e.error,
                     ),
-                    parent=parent_context,
+                    parent=execution.parent_context,
                 )
             )
         except Exception as e:
@@ -351,7 +350,7 @@ class WorkflowRunner(Generic[StateType]):
 
             self._workflow_event_inner_queue.put(
                 NodeExecutionRejectedEvent(
-                    trace_id=node.state.meta.trace_id,
+                    trace_id=execution.trace_id,
                     span_id=span_id,
                     body=NodeExecutionRejectedBody(
                         node_definition=node.__class__,
@@ -360,17 +359,18 @@ class WorkflowRunner(Generic[StateType]):
                             code=WorkflowErrorCode.INTERNAL_ERROR,
                         ),
                     ),
-                    parent=parent_context,
+                    parent=execution.parent_context,
                 ),
             )
 
         logger.debug(f"Finished running node: {node.__class__.__name__}")
 
     def _context_run_work_item(self, node: BaseNode[StateType], span_id: UUID, parent_context=None) -> None:
-        if parent_context is None:
-            parent_context = get_parent_context() or self._parent_context
-
-        with execution_context(parent_context=parent_context, trace_id=node.state.meta.trace_id):
+        execution = get_execution_context()
+        with execution_context(
+            parent_context=parent_context or execution.parent_context,
+            trace_id=execution.trace_id,
+        ):
             self._run_work_item(node, span_id)
 
     def _handle_invoked_ports(self, state: StateType, ports: Optional[Iterable[Port]]) -> None:
@@ -495,66 +495,67 @@ class WorkflowRunner(Generic[StateType]):
 
     def _initiate_workflow_event(self) -> WorkflowExecutionInitiatedEvent:
         return WorkflowExecutionInitiatedEvent(
-            trace_id=self._initial_state.meta.trace_id,
+            trace_id=self._execution_context.trace_id,
             span_id=self._initial_state.meta.span_id,
             body=WorkflowExecutionInitiatedBody(
                 workflow_definition=self.workflow.__class__,
                 inputs=self._initial_state.meta.workflow_inputs,
             ),
-            parent=self._parent_context,
+            parent=self._execution_context.parent_context,
         )
 
     def _stream_workflow_event(self, output: BaseOutput) -> WorkflowExecutionStreamingEvent:
         return WorkflowExecutionStreamingEvent(
-            trace_id=self._initial_state.meta.trace_id,
+            trace_id=self._execution_context.trace_id,
             span_id=self._initial_state.meta.span_id,
             body=WorkflowExecutionStreamingBody(
                 workflow_definition=self.workflow.__class__,
                 output=output,
            ),
-            parent=self._parent_context,
+            parent=self._execution_context.parent_context,
         )
 
     def _fulfill_workflow_event(self, outputs: OutputsType) -> WorkflowExecutionFulfilledEvent:
         return WorkflowExecutionFulfilledEvent(
-            trace_id=self._initial_state.meta.trace_id,
+            trace_id=self._execution_context.trace_id,
             span_id=self._initial_state.meta.span_id,
             body=WorkflowExecutionFulfilledBody(
                 workflow_definition=self.workflow.__class__,
                 outputs=outputs,
             ),
-            parent=self._parent_context,
+            parent=self._execution_context.parent_context,
         )
 
     def _reject_workflow_event(self, error: WorkflowError) -> WorkflowExecutionRejectedEvent:
         return WorkflowExecutionRejectedEvent(
-            trace_id=self._initial_state.meta.trace_id,
+            trace_id=self._execution_context.trace_id,
             span_id=self._initial_state.meta.span_id,
             body=WorkflowExecutionRejectedBody(
                 workflow_definition=self.workflow.__class__,
                 error=error,
             ),
-            parent=self._parent_context,
+            parent=self._execution_context.parent_context,
         )
 
     def _resume_workflow_event(self) -> WorkflowExecutionResumedEvent:
         return WorkflowExecutionResumedEvent(
-            trace_id=self._initial_state.meta.trace_id,
+            trace_id=self._execution_context.trace_id,
             span_id=self._initial_state.meta.span_id,
             body=WorkflowExecutionResumedBody(
                 workflow_definition=self.workflow.__class__,
             ),
+            parent=self._execution_context.parent_context,
         )
 
     def _pause_workflow_event(self, external_inputs: Iterable[ExternalInputReference]) -> WorkflowExecutionPausedEvent:
         return WorkflowExecutionPausedEvent(
-            trace_id=self._initial_state.meta.trace_id,
+            trace_id=self._execution_context.trace_id,
             span_id=self._initial_state.meta.span_id,
             body=WorkflowExecutionPausedBody(
                 workflow_definition=self.workflow.__class__,
                 external_inputs=external_inputs,
             ),
-            parent=self._parent_context,
+            parent=self._execution_context.parent_context,
         )
 
     def _stream(self) -> None:
@@ -564,13 +565,13 @@ class WorkflowRunner(Generic[StateType]):
         current_parent = WorkflowParentContext(
             span_id=self._initial_state.meta.span_id,
             workflow_definition=self.workflow.__class__,
-            parent=self._parent_context,
+            parent=self._execution_context.parent_context,
             type="WORKFLOW",
         )
         for node_cls in self._entrypoints:
             try:
                 if not self._max_concurrency or len(self._active_nodes_by_execution_id) < self._max_concurrency:
-                    with execution_context(parent_context=current_parent, trace_id=self._initial_state.meta.trace_id):
+                    with execution_context(parent_context=current_parent, trace_id=self._execution_context.trace_id):
                         self._run_node_if_ready(self._initial_state, node_cls)
                 else:
                     self._concurrency_queue.put((self._initial_state, node_cls, None))
@@ -600,7 +601,7 @@ class WorkflowRunner(Generic[StateType]):
 
             self._workflow_event_outer_queue.put(event)
 
-            with execution_context(parent_context=current_parent, trace_id=self._initial_state.meta.trace_id):
+            with execution_context(parent_context=current_parent, trace_id=self._execution_context.trace_id):
                 rejection_error = self._handle_work_item_event(event)
 
             if rejection_error:
@@ -611,7 +612,7 @@ class WorkflowRunner(Generic[StateType]):
             while event := self._workflow_event_inner_queue.get_nowait():
                 self._workflow_event_outer_queue.put(event)
 
-                with execution_context(parent_context=current_parent, trace_id=self._initial_state.meta.trace_id):
+                with execution_context(parent_context=current_parent, trace_id=self._execution_context.trace_id):
                     rejection_error = self._handle_work_item_event(event)
 
                 if rejection_error: