vellum-ai 0.10.9__py3-none-any.whl → 0.11.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (123) hide show
  1. vellum/__init__.py +16 -0
  2. vellum/client/core/client_wrapper.py +1 -1
  3. vellum/client/types/__init__.py +28 -0
  4. vellum/client/types/test_suite_run_exec_config.py +7 -1
  5. vellum/client/types/test_suite_run_exec_config_request.py +8 -0
  6. vellum/client/types/test_suite_run_prompt_sandbox_history_item_exec_config.py +31 -0
  7. vellum/client/types/test_suite_run_prompt_sandbox_history_item_exec_config_data.py +27 -0
  8. vellum/client/types/test_suite_run_prompt_sandbox_history_item_exec_config_data_request.py +27 -0
  9. vellum/client/types/test_suite_run_prompt_sandbox_history_item_exec_config_request.py +31 -0
  10. vellum/client/types/test_suite_run_workflow_sandbox_history_item_exec_config.py +31 -0
  11. vellum/client/types/test_suite_run_workflow_sandbox_history_item_exec_config_data.py +27 -0
  12. vellum/client/types/test_suite_run_workflow_sandbox_history_item_exec_config_data_request.py +27 -0
  13. vellum/client/types/test_suite_run_workflow_sandbox_history_item_exec_config_request.py +31 -0
  14. vellum/evaluations/resources.py +7 -12
  15. vellum/evaluations/utils/env.py +1 -3
  16. vellum/evaluations/utils/paginator.py +0 -1
  17. vellum/evaluations/utils/typing.py +1 -1
  18. vellum/evaluations/utils/uuid.py +1 -1
  19. vellum/plugins/vellum_mypy.py +3 -1
  20. vellum/types/test_suite_run_prompt_sandbox_history_item_exec_config.py +3 -0
  21. vellum/types/test_suite_run_prompt_sandbox_history_item_exec_config_data.py +3 -0
  22. vellum/types/test_suite_run_prompt_sandbox_history_item_exec_config_data_request.py +3 -0
  23. vellum/types/test_suite_run_prompt_sandbox_history_item_exec_config_request.py +3 -0
  24. vellum/types/test_suite_run_workflow_sandbox_history_item_exec_config.py +3 -0
  25. vellum/types/test_suite_run_workflow_sandbox_history_item_exec_config_data.py +3 -0
  26. vellum/types/test_suite_run_workflow_sandbox_history_item_exec_config_data_request.py +3 -0
  27. vellum/types/test_suite_run_workflow_sandbox_history_item_exec_config_request.py +3 -0
  28. vellum/workflows/context.py +42 -0
  29. vellum/workflows/events/node.py +7 -6
  30. vellum/workflows/events/tests/test_event.py +0 -1
  31. vellum/workflows/events/types.py +0 -1
  32. vellum/workflows/events/workflow.py +19 -1
  33. vellum/workflows/nodes/bases/base.py +17 -56
  34. vellum/workflows/nodes/bases/tests/test_base_node.py +0 -1
  35. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +13 -7
  36. vellum/workflows/nodes/core/templating_node/node.py +1 -0
  37. vellum/workflows/nodes/core/try_node/node.py +2 -2
  38. vellum/workflows/nodes/core/try_node/tests/test_node.py +1 -3
  39. vellum/workflows/nodes/displayable/api_node/node.py +3 -2
  40. vellum/workflows/nodes/displayable/bases/api_node/node.py +1 -1
  41. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +0 -1
  42. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +9 -1
  43. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +12 -2
  44. vellum/workflows/nodes/displayable/bases/search_node.py +0 -1
  45. vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +0 -1
  46. vellum/workflows/nodes/displayable/code_execution_node/utils.py +3 -2
  47. vellum/workflows/nodes/displayable/conditional_node/node.py +1 -1
  48. vellum/workflows/nodes/displayable/guardrail_node/node.py +0 -1
  49. vellum/workflows/nodes/displayable/inline_prompt_node/node.py +1 -0
  50. vellum/workflows/nodes/displayable/prompt_deployment_node/node.py +3 -1
  51. vellum/workflows/nodes/displayable/search_node/node.py +1 -0
  52. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +13 -3
  53. vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py +10 -7
  54. vellum/workflows/nodes/displayable/tests/test_search_node_wth_text_output.py +0 -1
  55. vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +1 -1
  56. vellum/workflows/outputs/base.py +2 -4
  57. vellum/workflows/ports/node_ports.py +1 -1
  58. vellum/workflows/runner/runner.py +167 -202
  59. vellum/workflows/state/base.py +0 -2
  60. vellum/workflows/types/core.py +1 -0
  61. vellum/workflows/types/tests/test_utils.py +1 -0
  62. vellum/workflows/types/utils.py +0 -1
  63. vellum/workflows/utils/functions.py +74 -0
  64. vellum/workflows/utils/tests/test_functions.py +171 -0
  65. vellum/workflows/utils/tests/test_vellum_variables.py +0 -1
  66. vellum/workflows/utils/vellum_variables.py +2 -2
  67. vellum/workflows/workflows/base.py +74 -34
  68. vellum/workflows/workflows/event_filters.py +7 -12
  69. {vellum_ai-0.10.9.dist-info → vellum_ai-0.11.1.dist-info}/METADATA +1 -1
  70. {vellum_ai-0.10.9.dist-info → vellum_ai-0.11.1.dist-info}/RECORD +122 -99
  71. vellum_cli/__init__.py +147 -13
  72. vellum_cli/config.py +0 -1
  73. vellum_cli/image_push.py +1 -1
  74. vellum_cli/pull.py +31 -19
  75. vellum_cli/push.py +9 -10
  76. vellum_cli/tests/__init__.py +0 -0
  77. vellum_cli/tests/conftest.py +40 -0
  78. vellum_cli/tests/test_main.py +11 -0
  79. vellum_cli/tests/test_pull.py +143 -71
  80. vellum_cli/tests/test_push.py +173 -0
  81. vellum_ee/workflows/display/base.py +1 -0
  82. vellum_ee/workflows/display/nodes/base_node_display.py +3 -2
  83. vellum_ee/workflows/display/nodes/base_node_vellum_display.py +2 -2
  84. vellum_ee/workflows/display/nodes/get_node_display_class.py +1 -1
  85. vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +1 -1
  86. vellum_ee/workflows/display/nodes/vellum/__init__.py +1 -1
  87. vellum_ee/workflows/display/nodes/vellum/api_node.py +54 -58
  88. vellum_ee/workflows/display/nodes/vellum/conditional_node.py +39 -22
  89. vellum_ee/workflows/display/nodes/vellum/error_node.py +3 -3
  90. vellum_ee/workflows/display/nodes/vellum/final_output_node.py +0 -2
  91. vellum_ee/workflows/display/nodes/vellum/guardrail_node.py +1 -1
  92. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +1 -1
  93. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +4 -2
  94. vellum_ee/workflows/display/nodes/vellum/map_node.py +11 -5
  95. vellum_ee/workflows/display/nodes/vellum/merge_node.py +2 -2
  96. vellum_ee/workflows/display/nodes/vellum/note_node.py +1 -3
  97. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +1 -1
  98. vellum_ee/workflows/display/nodes/vellum/search_node.py +1 -1
  99. vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +1 -1
  100. vellum_ee/workflows/display/nodes/vellum/templating_node.py +1 -1
  101. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +5 -5
  102. vellum_ee/workflows/display/nodes/vellum/utils.py +30 -10
  103. vellum_ee/workflows/display/tests/test_vellum_workflow_display.py +45 -0
  104. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +42 -25
  105. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +13 -39
  106. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +2 -2
  107. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +62 -58
  108. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +25 -4
  109. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +2 -1
  110. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +2 -2
  111. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +2 -2
  112. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +1 -1
  113. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +2 -1
  114. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +2 -2
  115. vellum_ee/workflows/display/types.py +4 -4
  116. vellum_ee/workflows/display/utils/vellum.py +2 -6
  117. vellum_ee/workflows/display/vellum.py +1 -1
  118. vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +4 -1
  119. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +12 -5
  120. vellum/workflows/runner/types.py +0 -16
  121. {vellum_ai-0.10.9.dist-info → vellum_ai-0.11.1.dist-info}/LICENSE +0 -0
  122. {vellum_ai-0.10.9.dist-info → vellum_ai-0.11.1.dist-info}/WHEEL +0 -0
  123. {vellum_ai-0.10.9.dist-info → vellum_ai-0.11.1.dist-info}/entry_points.txt +0 -0
@@ -7,6 +7,7 @@ from uuid import UUID
7
7
  from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Type, Union
8
8
 
9
9
  from vellum.workflows.constants import UNDEF
10
+ from vellum.workflows.context import execution_context, get_parent_context
10
11
  from vellum.workflows.descriptors.base import BaseDescriptor
11
12
  from vellum.workflows.edges.edge import Edge
12
13
  from vellum.workflows.errors import VellumError, VellumErrorCode
@@ -28,7 +29,7 @@ from vellum.workflows.events.node import (
28
29
  NodeExecutionRejectedBody,
29
30
  NodeExecutionStreamingBody,
30
31
  )
31
- from vellum.workflows.events.types import BaseEvent, ParentContext, WorkflowParentContext
32
+ from vellum.workflows.events.types import BaseEvent, NodeParentContext, ParentContext, WorkflowParentContext
32
33
  from vellum.workflows.events.workflow import (
33
34
  WorkflowExecutionFulfilledBody,
34
35
  WorkflowExecutionInitiatedBody,
@@ -37,6 +38,8 @@ from vellum.workflows.events.workflow import (
37
38
  WorkflowExecutionRejectedBody,
38
39
  WorkflowExecutionResumedBody,
39
40
  WorkflowExecutionResumedEvent,
41
+ WorkflowExecutionSnapshottedBody,
42
+ WorkflowExecutionSnapshottedEvent,
40
43
  WorkflowExecutionStreamingBody,
41
44
  )
42
45
  from vellum.workflows.exceptions import NodeException
@@ -45,7 +48,6 @@ from vellum.workflows.outputs import BaseOutputs
45
48
  from vellum.workflows.outputs.base import BaseOutput
46
49
  from vellum.workflows.ports.port import Port
47
50
  from vellum.workflows.references import ExternalInputReference, OutputReference
48
- from vellum.workflows.runner.types import WorkItemEvent
49
51
  from vellum.workflows.state.base import BaseState
50
52
  from vellum.workflows.types.generics import OutputsType, StateType, WorkflowInputsType
51
53
 
@@ -73,11 +75,10 @@ class WorkflowRunner(Generic[StateType]):
73
75
  parent_context: Optional[ParentContext] = None,
74
76
  ):
75
77
  if state and external_inputs:
76
- raise ValueError(
77
- "Can only run a Workflow providing one of state or external inputs, not both"
78
- )
78
+ raise ValueError("Can only run a Workflow providing one of state or external inputs, not both")
79
79
 
80
80
  self.workflow = workflow
81
+ self._is_resuming = False
81
82
  if entrypoint_nodes:
82
83
  if len(list(entrypoint_nodes)) > 1:
83
84
  raise ValueError("Cannot resume from multiple nodes")
@@ -100,10 +101,9 @@ class WorkflowRunner(Generic[StateType]):
100
101
  for ei in external_inputs
101
102
  if issubclass(ei.inputs_class.__parent_class__, BaseNode)
102
103
  ]
104
+ self._is_resuming = True
103
105
  else:
104
- normalized_inputs = (
105
- deepcopy(inputs) if inputs else self.workflow.get_default_inputs()
106
- )
106
+ normalized_inputs = deepcopy(inputs) if inputs else self.workflow.get_default_inputs()
107
107
  if state:
108
108
  self._initial_state = deepcopy(state)
109
109
  self._initial_state.meta.workflow_inputs = normalized_inputs
@@ -111,24 +111,42 @@ class WorkflowRunner(Generic[StateType]):
111
111
  self._initial_state = self.workflow.get_default_state(normalized_inputs)
112
112
  self._entrypoints = self.workflow.get_entrypoints()
113
113
 
114
- self._work_item_event_queue: Queue[WorkItemEvent[StateType]] = Queue()
115
- self._workflow_event_queue: Queue[WorkflowEvent] = Queue()
114
+ # This queue is responsible for sending events from WorkflowRunner to the outside world
115
+ self._workflow_event_outer_queue: Queue[WorkflowEvent] = Queue()
116
+
117
+ # This queue is responsible for sending events from the inner worker threads to WorkflowRunner
118
+ self._workflow_event_inner_queue: Queue[WorkflowEvent] = Queue()
119
+
120
+ # This queue is responsible for sending events from WorkflowRunner to the background thread
121
+ # for user defined emitters
116
122
  self._background_thread_queue: Queue[BackgroundThreadItem] = Queue()
123
+
117
124
  self._dependencies: Dict[Type[BaseNode], Set[Type[BaseNode]]] = defaultdict(set)
118
125
  self._state_forks: Set[StateType] = {self._initial_state}
119
126
 
120
127
  self._active_nodes_by_execution_id: Dict[UUID, BaseNode[StateType]] = {}
121
128
  self._cancel_signal = cancel_signal
122
- self._parent_context = parent_context
129
+ self._parent_context = get_parent_context() or parent_context
123
130
 
124
131
  setattr(
125
132
  self._initial_state,
126
133
  "__snapshot_callback__",
127
134
  lambda s: self._snapshot_state(s),
128
135
  )
129
- self.workflow.context._register_event_queue(self._workflow_event_queue)
136
+ self.workflow.context._register_event_queue(self._workflow_event_inner_queue)
130
137
 
131
138
  def _snapshot_state(self, state: StateType) -> StateType:
139
+ self._workflow_event_inner_queue.put(
140
+ WorkflowExecutionSnapshottedEvent(
141
+ trace_id=state.meta.trace_id,
142
+ span_id=state.meta.span_id,
143
+ body=WorkflowExecutionSnapshottedBody(
144
+ workflow_definition=self.workflow.__class__,
145
+ state=state,
146
+ ),
147
+ parent=self._parent_context,
148
+ )
149
+ )
132
150
  self.workflow._store.append_state_snapshot(state)
133
151
  self._background_thread_queue.put(state)
134
152
  return state
@@ -139,30 +157,29 @@ class WorkflowRunner(Generic[StateType]):
139
157
  return event
140
158
 
141
159
  def _run_work_item(self, node: BaseNode[StateType], span_id: UUID) -> None:
142
- self._work_item_event_queue.put(
143
- WorkItemEvent(
144
- node=node,
145
- event=NodeExecutionInitiatedEvent(
146
- trace_id=node.state.meta.trace_id,
147
- span_id=span_id,
148
- body=NodeExecutionInitiatedBody(
149
- node_definition=node.__class__,
150
- inputs=node._inputs,
151
- ),
152
- parent=WorkflowParentContext(
153
- span_id=span_id,
154
- workflow_definition=self.workflow.__class__,
155
- parent=self._parent_context,
156
- type="WORKFLOW",
157
- ),
160
+ parent_context = get_parent_context()
161
+ self._workflow_event_inner_queue.put(
162
+ NodeExecutionInitiatedEvent(
163
+ trace_id=node.state.meta.trace_id,
164
+ span_id=span_id,
165
+ body=NodeExecutionInitiatedBody(
166
+ node_definition=node.__class__,
167
+ inputs=node._inputs,
158
168
  ),
169
+ parent=parent_context,
159
170
  )
160
171
  )
161
172
 
162
173
  logger.debug(f"Started running node: {node.__class__.__name__}")
163
174
 
164
175
  try:
165
- node_run_response = node.run()
176
+ updated_parent_context = NodeParentContext(
177
+ span_id=span_id,
178
+ node_definition=node.__class__,
179
+ parent=parent_context,
180
+ )
181
+ with execution_context(parent_context=updated_parent_context):
182
+ node_run_response = node.run()
166
183
  ports = node.Ports()
167
184
  if not isinstance(node_run_response, (BaseOutputs, Iterator)):
168
185
  raise NodeException(
@@ -183,6 +200,7 @@ class WorkflowRunner(Generic[StateType]):
183
200
  outputs = node.Outputs()
184
201
 
185
202
  def initiate_node_streaming_output(output: BaseOutput) -> None:
203
+ parent_context = get_parent_context()
186
204
  streaming_output_queues[output.name] = Queue()
187
205
  output_descriptor = OutputReference(
188
206
  name=output.name,
@@ -190,44 +208,34 @@ class WorkflowRunner(Generic[StateType]):
190
208
  instance=None,
191
209
  outputs_class=node.Outputs,
192
210
  )
193
- node.state.meta.node_outputs[output_descriptor] = (
194
- streaming_output_queues[output.name]
195
- )
211
+ node.state.meta.node_outputs[output_descriptor] = streaming_output_queues[output.name]
196
212
  initiated_output: BaseOutput = BaseOutput(name=output.name)
197
213
  initiated_ports = initiated_output > ports
198
- self._work_item_event_queue.put(
199
- WorkItemEvent(
200
- node=node,
201
- event=NodeExecutionStreamingEvent(
202
- trace_id=node.state.meta.trace_id,
203
- span_id=span_id,
204
- body=NodeExecutionStreamingBody(
205
- node_definition=node.__class__,
206
- output=initiated_output,
207
- invoked_ports=initiated_ports,
208
- ),
209
- parent=WorkflowParentContext(
210
- span_id=span_id,
211
- workflow_definition=self.workflow.__class__,
212
- parent=self._parent_context,
213
- ),
214
+ self._workflow_event_inner_queue.put(
215
+ NodeExecutionStreamingEvent(
216
+ trace_id=node.state.meta.trace_id,
217
+ span_id=span_id,
218
+ body=NodeExecutionStreamingBody(
219
+ node_definition=node.__class__,
220
+ output=initiated_output,
221
+ invoked_ports=initiated_ports,
214
222
  ),
215
- )
223
+ parent=parent_context,
224
+ ),
216
225
  )
217
226
 
218
- for output in node_run_response:
219
- invoked_ports = output > ports
220
- if output.is_initiated:
221
- initiate_node_streaming_output(output)
222
- elif output.is_streaming:
223
- if output.name not in streaming_output_queues:
227
+ with execution_context(parent_context=updated_parent_context):
228
+ for output in node_run_response:
229
+ invoked_ports = output > ports
230
+ if output.is_initiated:
224
231
  initiate_node_streaming_output(output)
232
+ elif output.is_streaming:
233
+ if output.name not in streaming_output_queues:
234
+ initiate_node_streaming_output(output)
225
235
 
226
- streaming_output_queues[output.name].put(output.delta)
227
- self._work_item_event_queue.put(
228
- WorkItemEvent(
229
- node=node,
230
- event=NodeExecutionStreamingEvent(
236
+ streaming_output_queues[output.name].put(output.delta)
237
+ self._workflow_event_inner_queue.put(
238
+ NodeExecutionStreamingEvent(
231
239
  trace_id=node.state.meta.trace_id,
232
240
  span_id=span_id,
233
241
  body=NodeExecutionStreamingBody(
@@ -235,23 +243,16 @@ class WorkflowRunner(Generic[StateType]):
235
243
  output=output,
236
244
  invoked_ports=invoked_ports,
237
245
  ),
238
- parent=WorkflowParentContext(
239
- span_id=span_id,
240
- workflow_definition=self.workflow.__class__,
241
- parent=self._parent_context,
242
- ),
246
+ parent=parent_context,
243
247
  ),
244
248
  )
245
- )
246
- elif output.is_fulfilled:
247
- if output.name in streaming_output_queues:
248
- streaming_output_queues[output.name].put(UNDEF)
249
-
250
- setattr(outputs, output.name, output.value)
251
- self._work_item_event_queue.put(
252
- WorkItemEvent(
253
- node=node,
254
- event=NodeExecutionStreamingEvent(
249
+ elif output.is_fulfilled:
250
+ if output.name in streaming_output_queues:
251
+ streaming_output_queues[output.name].put(UNDEF)
252
+
253
+ setattr(outputs, output.name, output.value)
254
+ self._workflow_event_inner_queue.put(
255
+ NodeExecutionStreamingEvent(
255
256
  trace_id=node.state.meta.trace_id,
256
257
  span_id=span_id,
257
258
  body=NodeExecutionStreamingBody(
@@ -259,14 +260,12 @@ class WorkflowRunner(Generic[StateType]):
259
260
  output=output,
260
261
  invoked_ports=invoked_ports,
261
262
  ),
262
- parent=WorkflowParentContext(
263
- span_id=span_id,
264
- workflow_definition=self.workflow.__class__,
265
- parent=self._parent_context,
266
- ),
267
- ),
263
+ parent=parent_context,
264
+ )
268
265
  )
269
- )
266
+
267
+ invoked_ports = ports(outputs, node.state)
268
+ node.state.meta.node_execution_cache.fulfill_node_execution(node.__class__, span_id)
270
269
 
271
270
  for descriptor, output_value in outputs:
272
271
  if output_value is UNDEF:
@@ -276,81 +275,62 @@ class WorkflowRunner(Generic[StateType]):
276
275
 
277
276
  node.state.meta.node_outputs[descriptor] = output_value
278
277
 
279
- invoked_ports = ports(outputs, node.state)
280
- node.state.meta.node_execution_cache.fulfill_node_execution(
281
- node.__class__, span_id
282
- )
283
-
284
- self._work_item_event_queue.put(
285
- WorkItemEvent(
286
- node=node,
287
- event=NodeExecutionFulfilledEvent(
288
- trace_id=node.state.meta.trace_id,
289
- span_id=span_id,
290
- body=NodeExecutionFulfilledBody(
291
- node_definition=node.__class__,
292
- outputs=outputs,
293
- invoked_ports=invoked_ports,
294
- ),
295
- parent=WorkflowParentContext(
296
- span_id=span_id,
297
- workflow_definition=self.workflow.__class__,
298
- parent=self._parent_context,
299
- ),
278
+ self._workflow_event_inner_queue.put(
279
+ NodeExecutionFulfilledEvent(
280
+ trace_id=node.state.meta.trace_id,
281
+ span_id=span_id,
282
+ body=NodeExecutionFulfilledBody(
283
+ node_definition=node.__class__,
284
+ outputs=outputs,
285
+ invoked_ports=invoked_ports,
300
286
  ),
287
+ parent=parent_context,
301
288
  )
302
289
  )
303
290
  except NodeException as e:
304
- self._work_item_event_queue.put(
305
- WorkItemEvent(
306
- node=node,
307
- event=NodeExecutionRejectedEvent(
308
- trace_id=node.state.meta.trace_id,
291
+ self._workflow_event_inner_queue.put(
292
+ NodeExecutionRejectedEvent(
293
+ trace_id=node.state.meta.trace_id,
294
+ span_id=span_id,
295
+ body=NodeExecutionRejectedBody(
296
+ node_definition=node.__class__,
297
+ error=e.error,
298
+ ),
299
+ parent=WorkflowParentContext(
309
300
  span_id=span_id,
310
- body=NodeExecutionRejectedBody(
311
- node_definition=node.__class__,
312
- error=e.error,
313
- ),
314
- parent=WorkflowParentContext(
315
- span_id=span_id,
316
- workflow_definition=self.workflow.__class__,
317
- parent=self._parent_context,
318
- ),
301
+ workflow_definition=self.workflow.__class__,
302
+ parent=self._parent_context,
319
303
  ),
320
304
  )
321
305
  )
322
306
  except Exception as e:
323
- logger.exception(
324
- f"An unexpected error occurred while running node {node.__class__.__name__}"
325
- )
307
+ logger.exception(f"An unexpected error occurred while running node {node.__class__.__name__}")
326
308
 
327
- self._work_item_event_queue.put(
328
- WorkItemEvent(
329
- node=node,
330
- event=NodeExecutionRejectedEvent(
331
- trace_id=node.state.meta.trace_id,
332
- span_id=span_id,
333
- body=NodeExecutionRejectedBody(
334
- node_definition=node.__class__,
335
- error=VellumError(
336
- message=str(e),
337
- code=VellumErrorCode.INTERNAL_ERROR,
338
- ),
339
- ),
340
- parent=WorkflowParentContext(
341
- span_id=span_id,
342
- workflow_definition=self.workflow.__class__,
343
- parent=self._parent_context,
309
+ self._workflow_event_inner_queue.put(
310
+ NodeExecutionRejectedEvent(
311
+ trace_id=node.state.meta.trace_id,
312
+ span_id=span_id,
313
+ body=NodeExecutionRejectedBody(
314
+ node_definition=node.__class__,
315
+ error=VellumError(
316
+ message=str(e),
317
+ code=VellumErrorCode.INTERNAL_ERROR,
344
318
  ),
345
319
  ),
346
- )
320
+ parent=parent_context,
321
+ ),
347
322
  )
348
323
 
349
324
  logger.debug(f"Finished running node: {node.__class__.__name__}")
350
325
 
351
- def _handle_invoked_ports(
352
- self, state: StateType, ports: Optional[Iterable[Port]]
353
- ) -> None:
326
+ def _context_run_work_item(self, node: BaseNode[StateType], span_id: UUID, parent_context=None) -> None:
327
+ if parent_context is None:
328
+ parent_context = get_parent_context() or self._parent_context
329
+
330
+ with execution_context(parent_context=parent_context):
331
+ self._run_work_item(node, span_id)
332
+
333
+ def _handle_invoked_ports(self, state: StateType, ports: Optional[Iterable[Port]]) -> None:
354
334
  if not ports:
355
335
  return
356
336
 
@@ -380,31 +360,24 @@ class WorkflowRunner(Generic[StateType]):
380
360
  return
381
361
 
382
362
  all_deps = self._dependencies[node_class]
383
- node_span_id = state.meta.node_execution_cache.queue_node_execution(
384
- node_class, all_deps, invoked_by
385
- )
363
+ node_span_id = state.meta.node_execution_cache.queue_node_execution(node_class, all_deps, invoked_by)
386
364
  if not node_class.Trigger.should_initiate(state, all_deps, node_span_id):
387
365
  return
388
366
 
367
+ current_parent = get_parent_context()
389
368
  node = node_class(state=state, context=self.workflow.context)
390
- state.meta.node_execution_cache.initiate_node_execution(
391
- node_class, node_span_id
392
- )
369
+ state.meta.node_execution_cache.initiate_node_execution(node_class, node_span_id)
393
370
  self._active_nodes_by_execution_id[node_span_id] = node
394
371
 
395
372
  worker_thread = Thread(
396
- target=self._run_work_item,
397
- kwargs={"node": node, "span_id": node_span_id},
373
+ target=self._context_run_work_item,
374
+ kwargs={"node": node, "span_id": node_span_id, "parent_context": current_parent},
398
375
  )
399
376
  worker_thread.start()
400
377
 
401
- def _handle_work_item_event(
402
- self, work_item_event: WorkItemEvent[StateType]
403
- ) -> Optional[VellumError]:
404
- node = work_item_event.node
405
- event = work_item_event.event
406
-
407
- if event.name == "node.execution.initiated":
378
+ def _handle_work_item_event(self, event: WorkflowEvent) -> Optional[VellumError]:
379
+ node = self._active_nodes_by_execution_id.get(event.span_id)
380
+ if not node:
408
381
  return None
409
382
 
410
383
  if event.name == "node.execution.rejected":
@@ -416,15 +389,12 @@ class WorkflowRunner(Generic[StateType]):
416
389
  node_output_descriptor = workflow_output_descriptor.instance
417
390
  if not isinstance(node_output_descriptor, OutputReference):
418
391
  continue
419
- if (
420
- node_output_descriptor.outputs_class
421
- != event.node_definition.Outputs
422
- ):
392
+ if node_output_descriptor.outputs_class != event.node_definition.Outputs:
423
393
  continue
424
394
  if node_output_descriptor.name != event.output.name:
425
395
  continue
426
396
 
427
- self._workflow_event_queue.put(
397
+ self._workflow_event_outer_queue.put(
428
398
  self._stream_workflow_event(
429
399
  BaseOutput(
430
400
  name=workflow_output_descriptor.name,
@@ -444,7 +414,7 @@ class WorkflowRunner(Generic[StateType]):
444
414
 
445
415
  return None
446
416
 
447
- raise ValueError(f"Invalid event name: {event.name}")
417
+ return None
448
418
 
449
419
  def _initiate_workflow_event(self) -> WorkflowExecutionInitiatedEvent:
450
420
  return WorkflowExecutionInitiatedEvent(
@@ -457,9 +427,7 @@ class WorkflowRunner(Generic[StateType]):
457
427
  parent=self._parent_context,
458
428
  )
459
429
 
460
- def _stream_workflow_event(
461
- self, output: BaseOutput
462
- ) -> WorkflowExecutionStreamingEvent:
430
+ def _stream_workflow_event(self, output: BaseOutput) -> WorkflowExecutionStreamingEvent:
463
431
  return WorkflowExecutionStreamingEvent(
464
432
  trace_id=self._initial_state.meta.trace_id,
465
433
  span_id=self._initial_state.meta.span_id,
@@ -470,9 +438,7 @@ class WorkflowRunner(Generic[StateType]):
470
438
  parent=self._parent_context,
471
439
  )
472
440
 
473
- def _fulfill_workflow_event(
474
- self, outputs: OutputsType
475
- ) -> WorkflowExecutionFulfilledEvent:
441
+ def _fulfill_workflow_event(self, outputs: OutputsType) -> WorkflowExecutionFulfilledEvent:
476
442
  return WorkflowExecutionFulfilledEvent(
477
443
  trace_id=self._initial_state.meta.trace_id,
478
444
  span_id=self._initial_state.meta.span_id,
@@ -483,9 +449,7 @@ class WorkflowRunner(Generic[StateType]):
483
449
  parent=self._parent_context,
484
450
  )
485
451
 
486
- def _reject_workflow_event(
487
- self, error: VellumError
488
- ) -> WorkflowExecutionRejectedEvent:
452
+ def _reject_workflow_event(self, error: VellumError) -> WorkflowExecutionRejectedEvent:
489
453
  return WorkflowExecutionRejectedEvent(
490
454
  trace_id=self._initial_state.meta.trace_id,
491
455
  span_id=self._initial_state.meta.span_id,
@@ -505,9 +469,7 @@ class WorkflowRunner(Generic[StateType]):
505
469
  ),
506
470
  )
507
471
 
508
- def _pause_workflow_event(
509
- self, external_inputs: Iterable[ExternalInputReference]
510
- ) -> WorkflowExecutionPausedEvent:
472
+ def _pause_workflow_event(self, external_inputs: Iterable[ExternalInputReference]) -> WorkflowExecutionPausedEvent:
511
473
  return WorkflowExecutionPausedEvent(
512
474
  trace_id=self._initial_state.meta.trace_id,
513
475
  span_id=self._initial_state.meta.span_id,
@@ -522,7 +484,7 @@ class WorkflowRunner(Generic[StateType]):
522
484
  # TODO: We should likely handle this during initialization
523
485
  # https://app.shortcut.com/vellum/story/4327
524
486
  if not self._entrypoints:
525
- self._workflow_event_queue.put(
487
+ self._workflow_event_outer_queue.put(
526
488
  self._reject_workflow_event(
527
489
  VellumError(
528
490
  message="No entrypoints defined",
@@ -535,20 +497,25 @@ class WorkflowRunner(Generic[StateType]):
535
497
  for edge in self.workflow.get_edges():
536
498
  self._dependencies[edge.to_node].add(edge.from_port.node_class)
537
499
 
500
+ current_parent = WorkflowParentContext(
501
+ span_id=self._initial_state.meta.span_id,
502
+ workflow_definition=self.workflow.__class__,
503
+ parent=self._parent_context,
504
+ type="WORKFLOW",
505
+ )
538
506
  for node_cls in self._entrypoints:
539
507
  try:
540
- self._run_node_if_ready(self._initial_state, node_cls)
508
+ with execution_context(parent_context=current_parent):
509
+ self._run_node_if_ready(self._initial_state, node_cls)
541
510
  except NodeException as e:
542
- self._workflow_event_queue.put(self._reject_workflow_event(e.error))
511
+ self._workflow_event_outer_queue.put(self._reject_workflow_event(e.error))
543
512
  return
544
513
  except Exception:
545
514
  err_message = f"An unexpected error occurred while initializing node {node_cls.__name__}"
546
515
  logger.exception(err_message)
547
- self._workflow_event_queue.put(
516
+ self._workflow_event_outer_queue.put(
548
517
  self._reject_workflow_event(
549
- VellumError(
550
- code=VellumErrorCode.INTERNAL_ERROR, message=err_message
551
- ),
518
+ VellumError(code=VellumErrorCode.INTERNAL_ERROR, message=err_message),
552
519
  )
553
520
  )
554
521
  return
@@ -559,21 +526,24 @@ class WorkflowRunner(Generic[StateType]):
559
526
  if not self._active_nodes_by_execution_id:
560
527
  break
561
528
 
562
- work_item_event = self._work_item_event_queue.get()
563
- event = work_item_event.event
529
+ event = self._workflow_event_inner_queue.get()
530
+
531
+ self._workflow_event_outer_queue.put(event)
564
532
 
565
- self._workflow_event_queue.put(event)
533
+ with execution_context(parent_context=current_parent):
534
+ rejection_error = self._handle_work_item_event(event)
566
535
 
567
- rejection_error = self._handle_work_item_event(work_item_event)
568
536
  if rejection_error:
569
537
  break
570
538
 
571
539
  # Handle any remaining events
572
540
  try:
573
- while work_item_event := self._work_item_event_queue.get_nowait():
574
- self._workflow_event_queue.put(work_item_event.event)
541
+ while event := self._workflow_event_inner_queue.get_nowait():
542
+ self._workflow_event_outer_queue.put(event)
543
+
544
+ with execution_context(parent_context=current_parent):
545
+ rejection_error = self._handle_work_item_event(event)
575
546
 
576
- rejection_error = self._handle_work_item_event(work_item_event)
577
547
  if rejection_error:
578
548
  break
579
549
  except Empty:
@@ -589,14 +559,13 @@ class WorkflowRunner(Generic[StateType]):
589
559
  if node_input_value is UNDEF
590
560
  }
591
561
  if unresolved_external_inputs:
592
- self._workflow_event_queue.put(
562
+ self._workflow_event_outer_queue.put(
593
563
  self._pause_workflow_event(unresolved_external_inputs),
594
564
  )
595
565
  return
596
566
 
597
- final_state.meta.is_terminated = True
598
567
  if rejection_error:
599
- self._workflow_event_queue.put(self._reject_workflow_event(rejection_error))
568
+ self._workflow_event_outer_queue.put(self._reject_workflow_event(rejection_error))
600
569
  return
601
570
 
602
571
  fulfilled_outputs = self.workflow.Outputs()
@@ -610,7 +579,7 @@ class WorkflowRunner(Generic[StateType]):
610
579
  descriptor.instance.resolve(final_state),
611
580
  )
612
581
 
613
- self._workflow_event_queue.put(self._fulfill_workflow_event(fulfilled_outputs))
582
+ self._workflow_event_outer_queue.put(self._fulfill_workflow_event(fulfilled_outputs))
614
583
 
615
584
  def _run_background_thread(self) -> None:
616
585
  state_class = self.workflow.get_state_class()
@@ -631,7 +600,7 @@ class WorkflowRunner(Generic[StateType]):
631
600
  return
632
601
 
633
602
  self._cancel_signal.wait()
634
- self._workflow_event_queue.put(
603
+ self._workflow_event_outer_queue.put(
635
604
  self._reject_workflow_event(
636
605
  VellumError(
637
606
  code=VellumErrorCode.WORKFLOW_CANCELLED,
@@ -664,16 +633,12 @@ class WorkflowRunner(Generic[StateType]):
664
633
  cancel_thread.start()
665
634
 
666
635
  event: WorkflowEvent
667
- if (
668
- self._initial_state.meta.is_terminated
669
- or self._initial_state.meta.is_terminated is None
670
- ):
671
- event = self._initiate_workflow_event()
672
- else:
636
+ if self._is_resuming:
673
637
  event = self._resume_workflow_event()
638
+ else:
639
+ event = self._initiate_workflow_event()
674
640
 
675
641
  yield self._emit_event(event)
676
- self._initial_state.meta.is_terminated = False
677
642
 
678
643
  # The extra level of indirection prevents the runner from waiting on the caller to consume the event stream
679
644
  stream_thread = Thread(
@@ -684,7 +649,7 @@ class WorkflowRunner(Generic[StateType]):
684
649
 
685
650
  while stream_thread.is_alive():
686
651
  try:
687
- event = self._workflow_event_queue.get(timeout=0.1)
652
+ event = self._workflow_event_outer_queue.get(timeout=0.1)
688
653
  except Empty:
689
654
  continue
690
655
 
@@ -694,7 +659,7 @@ class WorkflowRunner(Generic[StateType]):
694
659
  break
695
660
 
696
661
  try:
697
- while event := self._workflow_event_queue.get_nowait():
662
+ while event := self._workflow_event_outer_queue.get_nowait():
698
663
  yield self._emit_event(event)
699
664
  except Empty:
700
665
  pass
@@ -192,7 +192,6 @@ class StateMeta(UniversalBaseModel):
192
192
  node_outputs: Dict[OutputReference, Any] = field(default_factory=dict)
193
193
  node_execution_cache: NodeExecutionCache = field(default_factory=NodeExecutionCache)
194
194
  parent: Optional["BaseState"] = None
195
- is_terminated: Optional[bool] = None
196
195
  __snapshot_callback__: Optional[Callable[[], None]] = field(init=False, default=None)
197
196
 
198
197
  def model_post_init(self, context: Any) -> None:
@@ -286,7 +285,6 @@ class BaseState(metaclass=_BaseStateMeta):
286
285
  {values}
287
286
  meta:
288
287
  id={self.meta.id}
289
- is_terminated={self.meta.is_terminated}
290
288
  updated_ts={self.meta.updated_ts}
291
289
  node_outputs:{' Empty' if not node_outputs else ''}
292
290
  {node_outputs}