vellum-ai 0.14.37__py3-none-any.whl → 0.14.38__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- vellum/__init__.py +8 -0
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/reference.md +6272 -0
- vellum/client/types/__init__.py +8 -0
- vellum/client/types/ad_hoc_fulfilled_prompt_execution_meta.py +2 -0
- vellum/client/types/fulfilled_prompt_execution_meta.py +2 -0
- vellum/client/types/test_suite_run_exec_config_request.py +4 -0
- vellum/client/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +27 -0
- vellum/client/types/test_suite_run_prompt_sandbox_exec_config_request.py +29 -0
- vellum/client/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +22 -0
- vellum/client/types/test_suite_run_workflow_sandbox_exec_config_request.py +29 -0
- vellum/plugins/pydantic.py +1 -1
- vellum/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +3 -0
- vellum/types/test_suite_run_prompt_sandbox_exec_config_request.py +3 -0
- vellum/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +3 -0
- vellum/types/test_suite_run_workflow_sandbox_exec_config_request.py +3 -0
- vellum/workflows/events/node.py +2 -1
- vellum/workflows/events/types.py +3 -40
- vellum/workflows/events/workflow.py +2 -1
- vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +94 -3
- vellum/workflows/nodes/displayable/conftest.py +2 -6
- vellum/workflows/nodes/displayable/guardrail_node/node.py +1 -1
- vellum/workflows/nodes/displayable/guardrail_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/guardrail_node/tests/test_node.py +50 -0
- vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +297 -0
- vellum/workflows/runner/runner.py +44 -43
- vellum/workflows/state/base.py +149 -45
- vellum/workflows/types/definition.py +71 -0
- vellum/workflows/types/generics.py +34 -1
- vellum/workflows/workflows/base.py +20 -3
- vellum/workflows/workflows/tests/test_base_workflow.py +232 -1
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/RECORD +37 -25
- vellum_ee/workflows/display/vellum.py +0 -5
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/entry_points.txt +0 -0
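
The most visible change in this release is fallback-model support on PromptDeploymentNode, exercised by the new tests shown below. As a rough usage sketch (the class name and values here are hypothetical; only the attribute names and the run() call come from the tests in this diff):

# Hypothetical usage sketch; only the attribute names are taken from the new tests.
from vellum.workflows.nodes.displayable.prompt_deployment_node.node import PromptDeploymentNode


class MyPromptNode(PromptDeploymentNode):
    deployment = "my-deployment"              # hypothetical deployment name
    prompt_inputs = {"query": "Hello there"}  # inputs passed to the deployed prompt
    # Models to try, in order, if the primary model fails with an API or provider error.
    ml_model_fallbacks = ["gpt-4o", "gemini-1.5-flash-latest"]


# run() yields a stream of outputs; the tests read the value off the last one.
outputs = list(MyPromptNode().run())
print(outputs[-1].value)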
vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py

@@ -5,6 +5,8 @@ from typing import Any, Iterator, List
 
 from httpx import Response
 
+from vellum import RejectedExecutePromptEvent
+from vellum.client import ApiError
 from vellum.client.types.chat_history_input_request import ChatHistoryInputRequest
 from vellum.client.types.chat_message import ChatMessage
 from vellum.client.types.chat_message_request import ChatMessageRequest
@@ -15,6 +17,8 @@ from vellum.client.types.json_input_request import JsonInputRequest
 from vellum.client.types.prompt_output import PromptOutput
 from vellum.client.types.string_vellum_value import StringVellumValue
 from vellum.workflows.context import execution_context
+from vellum.workflows.errors import WorkflowErrorCode
+from vellum.workflows.exceptions import NodeException
 from vellum.workflows.nodes.displayable.prompt_deployment_node.node import PromptDeploymentNode
 
 
@@ -194,3 +198,296 @@ def test_prompt_deployment_node__json_output(vellum_client):
     json_output = outputs[2]
     assert json_output.name == "json"
     assert json_output.value == expected_json
+
+
+def test_prompt_deployment_node__all_fallbacks_fail(vellum_client):
+    # GIVEN a Prompt Deployment Node with fallback models
+    class TestPromptDeploymentNode(PromptDeploymentNode):
+        deployment = "test_deployment"
+        prompt_inputs = {"query": "test query"}
+        ml_model_fallbacks = ["fallback_model_1", "fallback_model_2"]
+
+    # AND all models fail with 404 errors
+    primary_error = ApiError(
+        body={"detail": "Failed to find model 'primary_model'"},
+        status_code=404,
+    )
+    fallback1_error = ApiError(
+        body={"detail": "Failed to find model 'fallback_model_1'"},
+        status_code=404,
+    )
+    fallback2_error = ApiError(
+        body={"detail": "Failed to find model 'fallback_model_2'"},
+        status_code=404,
+    )
+
+    vellum_client.execute_prompt_stream.side_effect = [primary_error, fallback1_error, fallback2_error]
+
+    # WHEN we run the node
+    node = TestPromptDeploymentNode()
+
+    # THEN an exception should be raised
+    with pytest.raises(NodeException) as exc_info:
+        list(node.run())
+
+    # AND the client should have been called three times
+    assert vellum_client.execute_prompt_stream.call_count == 3
+
+    # AND we get the expected error message
+    assert (
+        exc_info.value.message
+        == "Failed to execute prompts with these fallbacks: ['fallback_model_1', 'fallback_model_2']"
+    )
+
+
+def test_prompt_deployment_node__fallback_success(vellum_client):
+    # GIVEN a Prompt Deployment Node with fallback models
+    class TestPromptDeploymentNode(PromptDeploymentNode):
+        deployment = "test_deployment"
+        prompt_inputs = {"query": "test query"}
+        ml_model_fallbacks = ["fallback_model_1", "fallback_model_2"]
+
+    # AND the primary model fails with a 404 error
+    primary_error = ApiError(
+        body={"detail": "Failed to find model 'primary_model'"},
+        status_code=404,
+    )
+
+    # AND the first fallback model succeeds
+    def generate_successful_stream():
+        execution_id = str(uuid4())
+        events = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            FulfilledExecutePromptEvent(
+                execution_id=execution_id, outputs=[StringVellumValue(value="Fallback response")]
+            ),
+        ]
+        return iter(events)
+
+    # Set up the mock to fail on primary but succeed on first fallback
+    vellum_client.execute_prompt_stream.side_effect = [primary_error, generate_successful_stream()]
+
+    # WHEN we run the node
+    node = TestPromptDeploymentNode()
+    outputs = list(node.run())
+
+    # THEN the node should complete successfully using the fallback model
+    assert len(outputs) > 0
+    assert outputs[-1].value == "Fallback response"
+
+    # AND the client should have been called twice (once for primary, once for fallback)
+    assert vellum_client.execute_prompt_stream.call_count == 2
+
+    # AND the second call should include the fallback model override
+    second_call_kwargs = vellum_client.execute_prompt_stream.call_args_list[1][1]
+    body_params = second_call_kwargs["request_options"]["additional_body_parameters"]
+    assert body_params["overrides"]["ml_model_fallback"] == "fallback_model_1"
+
+
+def test_prompt_deployment_node__provider_error_with_fallbacks(vellum_client):
+    # GIVEN a Prompt Deployment Node with fallback models
+    class TestPromptDeploymentNode(PromptDeploymentNode):
+        deployment = "test_deployment"
+        prompt_inputs = {}
+        ml_model_fallbacks = ["gpt-4o", "gemini-1.5-flash-latest"]
+
+    # AND the primary model starts but then fails with a provider error
+    def generate_primary_events():
+        execution_id = str(uuid4())
+        events = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            RejectedExecutePromptEvent(
+                execution_id=execution_id,
+                error={
+                    "code": "PROVIDER_ERROR",
+                    "message": "The model provider encountered an error",
+                },
+            ),
+        ]
+        return iter(events)
+
+    # AND the fallback model succeeds
+    def generate_fallback_events():
+        execution_id = str(uuid4())
+        expected_outputs: List[PromptOutput] = [StringVellumValue(value="Fallback response")]
+        events = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            FulfilledExecutePromptEvent(execution_id=execution_id, outputs=expected_outputs),
+        ]
+        return iter(events)
+
+    vellum_client.execute_prompt_stream.side_effect = [generate_primary_events(), generate_fallback_events()]
+
+    # WHEN we run the node
+    node = TestPromptDeploymentNode()
+    outputs = list(node.run())
+
+    # THEN the node should complete successfully using the fallback model
+    assert len(outputs) > 0
+    assert outputs[-1].value == "Fallback response"
+
+    # AND the client should have been called twice
+    assert vellum_client.execute_prompt_stream.call_count == 2
+
+    # AND the second call should include the fallback model override
+    second_call_kwargs = vellum_client.execute_prompt_stream.call_args_list[1][1]
+    body_params = second_call_kwargs["request_options"]["additional_body_parameters"]
+    assert body_params["overrides"]["ml_model_fallback"] == "gpt-4o"
+
+
+def test_prompt_deployment_node__multiple_fallbacks_mixed_errors(vellum_client):
+    """
+    This test case is when the primary model fails with an api error and
+    the first fallback fails with a provider error
+    """
+
+    # GIVEN a Prompt Deployment Node with multiple fallback models
+    class TestPromptDeploymentNode(PromptDeploymentNode):
+        deployment = "test_deployment"
+        prompt_inputs = {}
+        ml_model_fallbacks = ["gpt-4o", "gemini-1.5-flash-latest"]
+
+    # AND the primary model fails with an API error
+    primary_error = ApiError(
+        body={"detail": "Failed to find model 'primary_model'"},
+        status_code=404,
+    )
+
+    # AND the first fallback model fails with a provider error
+    def generate_fallback1_events():
+        execution_id = str(uuid4())
+        events = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            RejectedExecutePromptEvent(
+                execution_id=execution_id,
+                error={
+                    "code": "PROVIDER_ERROR",
+                    "message": "The first fallback provider encountered an error",
+                },
+            ),
+        ]
+        return iter(events)
+
+    # AND the second fallback model succeeds
+    def generate_fallback2_events():
+        execution_id = str(uuid4())
+        expected_outputs: List[PromptOutput] = [StringVellumValue(value="Second fallback response")]
+        events = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            FulfilledExecutePromptEvent(execution_id=execution_id, outputs=expected_outputs),
+        ]
+        return iter(events)
+
+    vellum_client.execute_prompt_stream.side_effect = [
+        primary_error,
+        generate_fallback1_events(),
+        generate_fallback2_events(),
+    ]
+
+    # WHEN we run the node
+    node = TestPromptDeploymentNode()
+    outputs = list(node.run())
+
+    # THEN the node should complete successfully using the second fallback model
+    assert len(outputs) > 0
+    assert outputs[-1].value == "Second fallback response"
+
+    # AND the client should have been called three times
+    assert vellum_client.execute_prompt_stream.call_count == 3
+
+    # AND the calls should include the correct model overrides
+    first_fallback_call = vellum_client.execute_prompt_stream.call_args_list[1][1]
+    first_fallback_params = first_fallback_call["request_options"]["additional_body_parameters"]
+    assert first_fallback_params["overrides"]["ml_model_fallback"] == "gpt-4o"
+
+    second_fallback_call = vellum_client.execute_prompt_stream.call_args_list[2][1]
+    second_fallback_params = second_fallback_call["request_options"]["additional_body_parameters"]
+    assert second_fallback_params["overrides"]["ml_model_fallback"] == "gemini-1.5-flash-latest"
+
+
+def test_prompt_deployment_node_multiple_provider_errors(vellum_client):
+    # GIVEN a Prompt Deployment Node with a single fallback model
+    class TestPromptDeploymentNode(PromptDeploymentNode):
+        deployment = "test_deployment"
+        prompt_inputs = {}
+        ml_model_fallbacks = ["gpt-4o"]
+
+    # AND the primary model fails with a provider error
+    def generate_primary_events():
+        execution_id = str(uuid4())
+        events = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            RejectedExecutePromptEvent(
+                execution_id=execution_id,
+                error={
+                    "code": "PROVIDER_ERROR",
+                    "message": "The primary provider encountered an error",
+                },
+            ),
+        ]
+        return iter(events)
+
+    # AND the fallback model also fails with a provider error
+    def generate_fallback1_events():
+        execution_id = str(uuid4())
+        events = [
+            InitiatedExecutePromptEvent(execution_id=execution_id),
+            RejectedExecutePromptEvent(
+                execution_id=execution_id,
+                error={
+                    "code": "PROVIDER_ERROR",
+                    "message": "The first fallback provider encountered an error",
+                },
+            ),
+        ]
+        return iter(events)
+
+    vellum_client.execute_prompt_stream.side_effect = [
+        generate_primary_events(),
+        generate_fallback1_events(),
+    ]
+
+    # WHEN we run the node
+    with pytest.raises(NodeException) as exc_info:
+        node = TestPromptDeploymentNode()
+        list(node.run())
+
+    # THEN we should get an exception
+    assert exc_info.value.message == "Failed to execute prompts with these fallbacks: ['gpt-4o']"
+
+    # AND the client should have been called two times
+    assert vellum_client.execute_prompt_stream.call_count == 2
+
+    # AND the calls should include the correct model overrides
+    first_fallback_call = vellum_client.execute_prompt_stream.call_args_list[1][1]
+    first_fallback_params = first_fallback_call["request_options"]["additional_body_parameters"]
+    assert first_fallback_params["overrides"]["ml_model_fallback"] == "gpt-4o"
+
+
+def test_prompt_deployment_node__no_fallbacks(vellum_client):
+    # GIVEN a Prompt Deployment Node with no fallback models
+    class TestPromptDeploymentNode(PromptDeploymentNode):
+        deployment = "test_deployment"
+        prompt_inputs = {}
+
+    # AND the primary model fails with an API error
+    primary_error = ApiError(
+        body={"detail": "Failed to find model 'primary_model'"},
+        status_code=404,
+    )
+
+    vellum_client.execute_prompt_stream.side_effect = primary_error
+
+    # WHEN we run the node
+    node = TestPromptDeploymentNode()
+
+    # THEN the node should raise an exception
+    with pytest.raises(NodeException) as exc_info:
+        list(node.run())
+
+    # AND the exception should contain the original error message
+    assert exc_info.value.message == "Failed to find model 'primary_model'"
+    assert exc_info.value.code == WorkflowErrorCode.INVALID_INPUTS
+
+    # AND the client should have been called only once (for the primary model)
+    assert vellum_client.execute_prompt_stream.call_count == 1
vellum/workflows/runner/runner.py

@@ -146,7 +146,6 @@ class WorkflowRunner(Generic[StateType]):
         self._active_nodes_by_execution_id: Dict[UUID, ActiveNode[StateType]] = {}
         self._cancel_signal = cancel_signal
         self._execution_context = init_execution_context or get_execution_context()
-        self._parent_context = self._execution_context.parent_context
 
         setattr(
             self._initial_state,
@@ -159,13 +158,13 @@ class WorkflowRunner(Generic[StateType]):
     def _snapshot_state(self, state: StateType) -> StateType:
         self._workflow_event_inner_queue.put(
             WorkflowExecutionSnapshottedEvent(
-                trace_id=
+                trace_id=self._execution_context.trace_id,
                 span_id=state.meta.span_id,
                 body=WorkflowExecutionSnapshottedBody(
                     workflow_definition=self.workflow.__class__,
                     state=state,
                 ),
-                parent=self.
+                parent=self._execution_context.parent_context,
             )
         )
         self.workflow._store.append_state_snapshot(state)
@@ -178,16 +177,16 @@ class WorkflowRunner(Generic[StateType]):
         return event
 
     def _run_work_item(self, node: BaseNode[StateType], span_id: UUID) -> None:
-
+        execution = get_execution_context()
         self._workflow_event_inner_queue.put(
             NodeExecutionInitiatedEvent(
-                trace_id=
+                trace_id=execution.trace_id,
                 span_id=span_id,
                 body=NodeExecutionInitiatedBody(
                     node_definition=node.__class__,
                     inputs=node._inputs,
                 ),
-                parent=parent_context,
+                parent=execution.parent_context,
             )
         )
 
@@ -197,7 +196,7 @@ class WorkflowRunner(Generic[StateType]):
             updated_parent_context = NodeParentContext(
                 span_id=span_id,
                 node_definition=node.__class__,
-                parent=parent_context,
+                parent=execution.parent_context,
             )
             node_run_response: NodeRunResponse
             was_mocked: Optional[bool] = None
@@ -209,7 +208,7 @@ class WorkflowRunner(Generic[StateType]):
                     break
 
             if not was_mocked:
-                with execution_context(parent_context=updated_parent_context, trace_id=
+                with execution_context(parent_context=updated_parent_context, trace_id=execution.trace_id):
                     node_run_response = node.run()
 
             ports = node.Ports()
@@ -232,7 +231,7 @@ class WorkflowRunner(Generic[StateType]):
                 outputs = node.Outputs()
 
                 def initiate_node_streaming_output(output: BaseOutput) -> None:
-
+                    execution = get_execution_context()
                    streaming_output_queues[output.name] = Queue()
                    output_descriptor = OutputReference(
                        name=output.name,
@@ -245,18 +244,18 @@ class WorkflowRunner(Generic[StateType]):
                     initiated_ports = initiated_output > ports
                     self._workflow_event_inner_queue.put(
                         NodeExecutionStreamingEvent(
-                            trace_id=
+                            trace_id=execution.trace_id,
                             span_id=span_id,
                             body=NodeExecutionStreamingBody(
                                 node_definition=node.__class__,
                                 output=initiated_output,
                                 invoked_ports=initiated_ports,
                             ),
-                            parent=parent_context,
+                            parent=execution.parent_context,
                         ),
                     )
 
-                with execution_context(parent_context=updated_parent_context, trace_id=
+                with execution_context(parent_context=updated_parent_context, trace_id=execution.trace_id):
                     for output in node_run_response:
                         invoked_ports = output > ports
                         if output.is_initiated:
@@ -268,14 +267,14 @@ class WorkflowRunner(Generic[StateType]):
                             streaming_output_queues[output.name].put(output.delta)
                             self._workflow_event_inner_queue.put(
                                 NodeExecutionStreamingEvent(
-                                    trace_id=
+                                    trace_id=execution.trace_id,
                                     span_id=span_id,
                                     body=NodeExecutionStreamingBody(
                                         node_definition=node.__class__,
                                         output=output,
                                         invoked_ports=invoked_ports,
                                     ),
-                                    parent=parent_context,
+                                    parent=execution.parent_context,
                                 ),
                             )
                         elif output.is_fulfilled:
@@ -285,14 +284,14 @@ class WorkflowRunner(Generic[StateType]):
                             setattr(outputs, output.name, output.value)
                             self._workflow_event_inner_queue.put(
                                 NodeExecutionStreamingEvent(
-                                    trace_id=
+                                    trace_id=execution.trace_id,
                                     span_id=span_id,
                                     body=NodeExecutionStreamingBody(
                                         node_definition=node.__class__,
                                         output=output,
                                         invoked_ports=invoked_ports,
                                     ),
-                                    parent=parent_context,
+                                    parent=execution.parent_context,
                                 )
                             )
 
@@ -309,7 +308,7 @@ class WorkflowRunner(Generic[StateType]):
                 invoked_ports = ports(outputs, node.state)
             self._workflow_event_inner_queue.put(
                 NodeExecutionFulfilledEvent(
-                    trace_id=
+                    trace_id=execution.trace_id,
                     span_id=span_id,
                     body=NodeExecutionFulfilledBody(
                         node_definition=node.__class__,
@@ -317,33 +316,33 @@ class WorkflowRunner(Generic[StateType]):
                         invoked_ports=invoked_ports,
                         mocked=was_mocked,
                     ),
-                    parent=parent_context,
+                    parent=execution.parent_context,
                 )
             )
         except NodeException as e:
             logger.info(e)
             self._workflow_event_inner_queue.put(
                 NodeExecutionRejectedEvent(
-                    trace_id=
+                    trace_id=execution.trace_id,
                     span_id=span_id,
                     body=NodeExecutionRejectedBody(
                         node_definition=node.__class__,
                         error=e.error,
                     ),
-                    parent=parent_context,
+                    parent=execution.parent_context,
                 )
             )
         except WorkflowInitializationException as e:
             logger.info(e)
             self._workflow_event_inner_queue.put(
                 NodeExecutionRejectedEvent(
-                    trace_id=
+                    trace_id=execution.trace_id,
                     span_id=span_id,
                     body=NodeExecutionRejectedBody(
                         node_definition=node.__class__,
                         error=e.error,
                     ),
-                    parent=parent_context,
+                    parent=execution.parent_context,
                 )
             )
         except Exception as e:
@@ -351,7 +350,7 @@ class WorkflowRunner(Generic[StateType]):
 
             self._workflow_event_inner_queue.put(
                 NodeExecutionRejectedEvent(
-                    trace_id=
+                    trace_id=execution.trace_id,
                     span_id=span_id,
                     body=NodeExecutionRejectedBody(
                         node_definition=node.__class__,
@@ -360,17 +359,18 @@ class WorkflowRunner(Generic[StateType]):
                             code=WorkflowErrorCode.INTERNAL_ERROR,
                         ),
                     ),
-                    parent=parent_context,
+                    parent=execution.parent_context,
                 ),
             )
 
        logger.debug(f"Finished running node: {node.__class__.__name__}")
 
     def _context_run_work_item(self, node: BaseNode[StateType], span_id: UUID, parent_context=None) -> None:
-
-
-
-
+        execution = get_execution_context()
+        with execution_context(
+            parent_context=parent_context or execution.parent_context,
+            trace_id=execution.trace_id,
+        ):
             self._run_work_item(node, span_id)
 
     def _handle_invoked_ports(self, state: StateType, ports: Optional[Iterable[Port]]) -> None:
@@ -495,66 +495,67 @@ class WorkflowRunner(Generic[StateType]):
 
     def _initiate_workflow_event(self) -> WorkflowExecutionInitiatedEvent:
         return WorkflowExecutionInitiatedEvent(
-            trace_id=self.
+            trace_id=self._execution_context.trace_id,
             span_id=self._initial_state.meta.span_id,
             body=WorkflowExecutionInitiatedBody(
                 workflow_definition=self.workflow.__class__,
                 inputs=self._initial_state.meta.workflow_inputs,
            ),
-            parent=self.
+            parent=self._execution_context.parent_context,
         )
 
     def _stream_workflow_event(self, output: BaseOutput) -> WorkflowExecutionStreamingEvent:
         return WorkflowExecutionStreamingEvent(
-            trace_id=self.
+            trace_id=self._execution_context.trace_id,
             span_id=self._initial_state.meta.span_id,
             body=WorkflowExecutionStreamingBody(
                 workflow_definition=self.workflow.__class__,
                 output=output,
            ),
-            parent=self.
+            parent=self._execution_context.parent_context,
         )
 
     def _fulfill_workflow_event(self, outputs: OutputsType) -> WorkflowExecutionFulfilledEvent:
         return WorkflowExecutionFulfilledEvent(
-            trace_id=self.
+            trace_id=self._execution_context.trace_id,
            span_id=self._initial_state.meta.span_id,
            body=WorkflowExecutionFulfilledBody(
                workflow_definition=self.workflow.__class__,
                outputs=outputs,
            ),
-            parent=self.
+            parent=self._execution_context.parent_context,
         )
 
     def _reject_workflow_event(self, error: WorkflowError) -> WorkflowExecutionRejectedEvent:
         return WorkflowExecutionRejectedEvent(
-            trace_id=self.
+            trace_id=self._execution_context.trace_id,
            span_id=self._initial_state.meta.span_id,
            body=WorkflowExecutionRejectedBody(
                workflow_definition=self.workflow.__class__,
                error=error,
            ),
-            parent=self.
+            parent=self._execution_context.parent_context,
         )
 
     def _resume_workflow_event(self) -> WorkflowExecutionResumedEvent:
         return WorkflowExecutionResumedEvent(
-            trace_id=self.
+            trace_id=self._execution_context.trace_id,
            span_id=self._initial_state.meta.span_id,
            body=WorkflowExecutionResumedBody(
                workflow_definition=self.workflow.__class__,
            ),
+            parent=self._execution_context.parent_context,
         )
 
     def _pause_workflow_event(self, external_inputs: Iterable[ExternalInputReference]) -> WorkflowExecutionPausedEvent:
         return WorkflowExecutionPausedEvent(
-            trace_id=self.
+            trace_id=self._execution_context.trace_id,
            span_id=self._initial_state.meta.span_id,
            body=WorkflowExecutionPausedBody(
                workflow_definition=self.workflow.__class__,
                external_inputs=external_inputs,
            ),
-            parent=self.
+            parent=self._execution_context.parent_context,
         )
 
     def _stream(self) -> None:
@@ -564,13 +565,13 @@ class WorkflowRunner(Generic[StateType]):
         current_parent = WorkflowParentContext(
             span_id=self._initial_state.meta.span_id,
             workflow_definition=self.workflow.__class__,
-            parent=self.
+            parent=self._execution_context.parent_context,
             type="WORKFLOW",
         )
         for node_cls in self._entrypoints:
             try:
                 if not self._max_concurrency or len(self._active_nodes_by_execution_id) < self._max_concurrency:
-                    with execution_context(parent_context=current_parent, trace_id=self.
+                    with execution_context(parent_context=current_parent, trace_id=self._execution_context.trace_id):
                         self._run_node_if_ready(self._initial_state, node_cls)
                 else:
                     self._concurrency_queue.put((self._initial_state, node_cls, None))
@@ -600,7 +601,7 @@ class WorkflowRunner(Generic[StateType]):
 
             self._workflow_event_outer_queue.put(event)
 
-            with execution_context(parent_context=current_parent, trace_id=self.
+            with execution_context(parent_context=current_parent, trace_id=self._execution_context.trace_id):
                 rejection_error = self._handle_work_item_event(event)
 
             if rejection_error:
@@ -611,7 +612,7 @@ class WorkflowRunner(Generic[StateType]):
             while event := self._workflow_event_inner_queue.get_nowait():
                 self._workflow_event_outer_queue.put(event)
 
-                with execution_context(parent_context=current_parent, trace_id=self.
+                with execution_context(parent_context=current_parent, trace_id=self._execution_context.trace_id):
                     rejection_error = self._handle_work_item_event(event)
 
                 if rejection_error: