vellum-ai 0.11.0__py3-none-any.whl → 0.11.3__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (54) hide show
  1. vellum/__init__.py +18 -0
  2. vellum/client/core/client_wrapper.py +1 -1
  3. vellum/client/types/__init__.py +30 -0
  4. vellum/client/types/document_document_to_document_index.py +6 -0
  5. vellum/client/types/slim_document.py +2 -2
  6. vellum/client/types/slim_document_document_to_document_index.py +43 -0
  7. vellum/client/types/test_suite_run_exec_config.py +7 -1
  8. vellum/client/types/test_suite_run_exec_config_request.py +8 -0
  9. vellum/client/types/test_suite_run_prompt_sandbox_history_item_exec_config.py +31 -0
  10. vellum/client/types/test_suite_run_prompt_sandbox_history_item_exec_config_data.py +27 -0
  11. vellum/client/types/test_suite_run_prompt_sandbox_history_item_exec_config_data_request.py +27 -0
  12. vellum/client/types/test_suite_run_prompt_sandbox_history_item_exec_config_request.py +31 -0
  13. vellum/client/types/test_suite_run_workflow_sandbox_history_item_exec_config.py +31 -0
  14. vellum/client/types/test_suite_run_workflow_sandbox_history_item_exec_config_data.py +27 -0
  15. vellum/client/types/test_suite_run_workflow_sandbox_history_item_exec_config_data_request.py +27 -0
  16. vellum/client/types/test_suite_run_workflow_sandbox_history_item_exec_config_request.py +31 -0
  17. vellum/types/slim_document_document_to_document_index.py +3 -0
  18. vellum/types/test_suite_run_prompt_sandbox_history_item_exec_config.py +3 -0
  19. vellum/types/test_suite_run_prompt_sandbox_history_item_exec_config_data.py +3 -0
  20. vellum/types/test_suite_run_prompt_sandbox_history_item_exec_config_data_request.py +3 -0
  21. vellum/types/test_suite_run_prompt_sandbox_history_item_exec_config_request.py +3 -0
  22. vellum/types/test_suite_run_workflow_sandbox_history_item_exec_config.py +3 -0
  23. vellum/types/test_suite_run_workflow_sandbox_history_item_exec_config_data.py +3 -0
  24. vellum/types/test_suite_run_workflow_sandbox_history_item_exec_config_data_request.py +3 -0
  25. vellum/types/test_suite_run_workflow_sandbox_history_item_exec_config_request.py +3 -0
  26. vellum/workflows/context.py +42 -0
  27. vellum/workflows/descriptors/base.py +1 -1
  28. vellum/workflows/nodes/bases/tests/test_base_node.py +1 -1
  29. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +13 -7
  30. vellum/workflows/nodes/core/map_node/node.py +27 -4
  31. vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +21 -0
  32. vellum/workflows/nodes/displayable/api_node/node.py +3 -2
  33. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +9 -0
  34. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +10 -1
  35. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +10 -1
  36. vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +1 -1
  37. vellum/workflows/runner/runner.py +74 -70
  38. vellum/workflows/workflows/event_filters.py +4 -1
  39. {vellum_ai-0.11.0.dist-info → vellum_ai-0.11.3.dist-info}/METADATA +1 -1
  40. {vellum_ai-0.11.0.dist-info → vellum_ai-0.11.3.dist-info}/RECORD +54 -35
  41. vellum_cli/pull.py +3 -1
  42. vellum_cli/tests/test_pull.py +18 -0
  43. vellum_ee/workflows/display/base.py +1 -2
  44. vellum_ee/workflows/display/nodes/vellum/api_node.py +53 -54
  45. vellum_ee/workflows/display/nodes/vellum/merge_node.py +20 -1
  46. vellum_ee/workflows/display/nodes/vellum/templating_node.py +15 -4
  47. vellum_ee/workflows/display/nodes/vellum/utils.py +26 -6
  48. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +29 -1
  49. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +2 -2
  50. vellum_ee/workflows/display/vellum.py +4 -4
  51. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +18 -9
  52. {vellum_ai-0.11.0.dist-info → vellum_ai-0.11.3.dist-info}/LICENSE +0 -0
  53. {vellum_ai-0.11.0.dist-info → vellum_ai-0.11.3.dist-info}/WHEEL +0 -0
  54. {vellum_ai-0.11.0.dist-info → vellum_ai-0.11.3.dist-info}/entry_points.txt +0 -0
@@ -3,14 +3,18 @@ from queue import Empty, Queue
3
3
  from threading import Thread
4
4
  from typing import TYPE_CHECKING, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, overload
5
5
 
6
+ from vellum.workflows.context import execution_context, get_parent_context
6
7
  from vellum.workflows.descriptors.base import BaseDescriptor
7
8
  from vellum.workflows.errors.types import VellumErrorCode
9
+ from vellum.workflows.events.types import ParentContext
8
10
  from vellum.workflows.exceptions import NodeException
9
11
  from vellum.workflows.inputs.base import BaseInputs
10
12
  from vellum.workflows.nodes.bases import BaseNode
11
13
  from vellum.workflows.outputs import BaseOutputs
12
14
  from vellum.workflows.state.base import BaseState
15
+ from vellum.workflows.state.context import WorkflowContext
13
16
  from vellum.workflows.types.generics import NodeType, StateType
17
+ from vellum.workflows.workflows.event_filters import all_workflow_event_filter
14
18
 
15
19
  if TYPE_CHECKING:
16
20
  from vellum.workflows import BaseWorkflow
@@ -53,7 +57,15 @@ class MapNode(BaseNode, Generic[StateType, MapNodeItemType]):
53
57
  fulfilled_iterations: List[bool] = []
54
58
  for index, item in enumerate(self.items):
55
59
  fulfilled_iterations.append(False)
56
- thread = Thread(target=self._run_subworkflow, kwargs={"item": item, "index": index})
60
+ parent_context = get_parent_context() or self._context.parent_context
61
+ thread = Thread(
62
+ target=self._context_run_subworkflow,
63
+ kwargs={
64
+ "item": item,
65
+ "index": index,
66
+ "parent_context": parent_context,
67
+ },
68
+ )
57
69
  thread.start()
58
70
 
59
71
  try:
@@ -62,6 +74,7 @@ class MapNode(BaseNode, Generic[StateType, MapNodeItemType]):
62
74
  while map_node_event := self._event_queue.get():
63
75
  index = map_node_event[0]
64
76
  terminal_event = map_node_event[1]
77
+ self._context._emit_subworkflow_event(terminal_event)
65
78
 
66
79
  if terminal_event.name == "workflow.execution.fulfilled":
67
80
  workflow_output_vars = vars(terminal_event.outputs)
@@ -85,12 +98,22 @@ class MapNode(BaseNode, Generic[StateType, MapNodeItemType]):
85
98
  )
86
99
  except Empty:
87
100
  pass
88
-
89
101
  return self.Outputs(**mapped_items)
90
102
 
103
def _context_run_subworkflow(
    self,
    *,
    item: MapNodeItemType,
    index: int,
    parent_context: Optional[ParentContext] = None,
) -> None:
    """Thread target: run one map iteration inside a bound execution context.

    Args:
        item: The element of ``self.items`` handled by this iteration.
        index: Position of ``item`` within ``self.items``.
        parent_context: Parent context captured by the spawning thread. A falsy
            value falls back to ``self._context.parent_context``.
    """
    effective_parent = parent_context if parent_context else self._context.parent_context
    with execution_context(parent_context=effective_parent):
        self._run_subworkflow(item=item, index=index)
109
+
91
110
  def _run_subworkflow(self, *, item: MapNodeItemType, index: int) -> None:
92
- subworkflow = self.subworkflow(parent_state=self.state, context=self._context)
93
- events = subworkflow.stream(inputs=self.SubworkflowInputs(index=index, item=item, all_items=self.items))
111
+ context = WorkflowContext(_vellum_client=self._context._vellum_client)
112
+ subworkflow = self.subworkflow(parent_state=self.state, context=context)
113
+ events = subworkflow.stream(
114
+ inputs=self.SubworkflowInputs(index=index, item=item, all_items=self.items),
115
+ event_filter=all_workflow_event_filter,
116
+ )
94
117
 
95
118
  for event in events:
96
119
  self._event_queue.put((index, event))
@@ -1,5 +1,6 @@
1
1
  import json
2
2
 
3
+ from vellum.workflows.nodes.bases.base import BaseNode
3
4
  from vellum.workflows.nodes.core.templating_node.node import TemplatingNode
4
5
 
5
6
 
@@ -19,3 +20,23 @@ def test_templating_node__dict_output():
19
20
 
20
21
  # THEN the output is json serializable
21
22
  assert json.loads(outputs.result) == {"key": "value"}
23
+
24
+
25
def test_templating_node__execution_count_reference():
    # GIVEN a node that this test never runs
    class OtherNode(BaseNode):
        pass

    # AND a templating node whose only input is that node's execution count
    class TemplateNode(TemplatingNode):
        template = "{{ total }}"
        inputs = {"total": OtherNode.Execution.count}

    # WHEN the templating node runs on its own
    result_outputs = TemplateNode().run()

    # THEN the rendered template is the (zero) execution count
    assert result_outputs.result == "0"
@@ -18,13 +18,14 @@ class APINode(BaseAPINode):
18
18
 
19
19
  authorization_type: Optional[AuthorizationType] = None - The type of authorization to use for the API call.
20
20
  api_key_header_key: Optional[str] = None - The header key to use for the API key authorization.
21
- bearer_token_value: Optional[str] = None - The bearer token value to use for the bearer token authorization.
21
+ bearer_token_value: Optional[Union[str, VellumSecretReference]] = None - The bearer token value to use
22
+ for the bearer token authorization.
22
23
  """
23
24
 
24
25
  authorization_type: Optional[AuthorizationType] = None
25
26
  api_key_header_key: Optional[str] = None
26
27
  api_key_header_value: Optional[Union[str, VellumSecretReference]] = None
27
- bearer_token_value: Optional[str] = None
28
+ bearer_token_value: Optional[Union[str, VellumSecretReference]] = None
28
29
 
29
30
  def run(self) -> BaseAPINode.Outputs:
30
31
  headers = self.headers or {}
@@ -14,7 +14,9 @@ from vellum import (
14
14
  PromptRequestStringInput,
15
15
  VellumVariable,
16
16
  )
17
+ from vellum.client import RequestOptions
17
18
  from vellum.workflows.constants import OMIT
19
+ from vellum.workflows.context import get_parent_context
18
20
  from vellum.workflows.errors import VellumErrorCode
19
21
  from vellum.workflows.exceptions import NodeException
20
22
  from vellum.workflows.nodes.displayable.bases.base_prompt_node import BasePromptNode
@@ -48,6 +50,13 @@ class BaseInlinePromptNode(BasePromptNode, Generic[StateType]):
48
50
 
49
51
  def _get_prompt_event_stream(self) -> Iterator[AdHocExecutePromptEvent]:
50
52
  input_variables, input_values = self._compile_prompt_inputs()
53
+ current_parent_context = get_parent_context()
54
+ parent_context = current_parent_context.model_dump_json() if current_parent_context else None
55
+ request_options = self.request_options or RequestOptions()
56
+ request_options["additional_body_parameters"] = {
57
+ "execution_context": {"parent_context": parent_context},
58
+ **request_options.get("additional_body_parameters", {}),
59
+ }
51
60
 
52
61
  return self._context.vellum_client.ad_hoc.adhoc_execute_prompt_stream(
53
62
  ml_model=self.ml_model,
@@ -11,7 +11,9 @@ from vellum import (
11
11
  RawPromptExecutionOverridesRequest,
12
12
  StringInputRequest,
13
13
  )
14
+ from vellum.client import RequestOptions
14
15
  from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT
16
+ from vellum.workflows.context import get_parent_context
15
17
  from vellum.workflows.errors import VellumErrorCode
16
18
  from vellum.workflows.exceptions import NodeException
17
19
  from vellum.workflows.nodes.displayable.bases.base_prompt_node import BasePromptNode
@@ -46,6 +48,13 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
46
48
  metadata: Optional[Dict[str, Optional[Any]]] = OMIT
47
49
 
48
50
  def _get_prompt_event_stream(self) -> Iterator[ExecutePromptEvent]:
51
+ current_parent_context = get_parent_context()
52
+ parent_context = current_parent_context.model_dump() if current_parent_context else None
53
+ request_options = self.request_options or RequestOptions()
54
+ request_options["additional_body_parameters"] = {
55
+ "execution_context": {"parent_context": parent_context},
56
+ **request_options.get("additional_body_parameters", {}),
57
+ }
49
58
  return self._context.vellum_client.execute_prompt_stream(
50
59
  inputs=self._compile_prompt_inputs(),
51
60
  prompt_deployment_id=str(self.deployment) if isinstance(self.deployment, UUID) else None,
@@ -56,7 +65,7 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
56
65
  raw_overrides=self.raw_overrides,
57
66
  expand_raw=self.expand_raw,
58
67
  metadata=self.metadata,
59
- request_options=self.request_options,
68
+ request_options=request_options,
60
69
  )
61
70
 
62
71
  def _compile_prompt_inputs(self) -> List[PromptDeploymentInputRequest]:
@@ -13,6 +13,7 @@ from vellum import (
13
13
  )
14
14
  from vellum.core import RequestOptions
15
15
  from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT
16
+ from vellum.workflows.context import get_parent_context
16
17
  from vellum.workflows.errors import VellumErrorCode
17
18
  from vellum.workflows.exceptions import NodeException
18
19
  from vellum.workflows.nodes.bases.base_subworkflow_node.node import BaseSubworkflowNode
@@ -89,6 +90,13 @@ class SubworkflowDeploymentNode(BaseSubworkflowNode[StateType], Generic[StateTyp
89
90
  return compiled_inputs
90
91
 
91
92
  def run(self) -> Iterator[BaseOutput]:
93
+ current_parent_context = get_parent_context()
94
+ parent_context = current_parent_context.model_dump(mode="json") if current_parent_context else None
95
+ request_options = self.request_options or RequestOptions()
96
+ request_options["additional_body_parameters"] = {
97
+ "execution_context": {"parent_context": parent_context},
98
+ **request_options.get("additional_body_parameters", {}),
99
+ }
92
100
  subworkflow_stream = self._context.vellum_client.execute_workflow_stream(
93
101
  inputs=self._compile_subworkflow_inputs(),
94
102
  workflow_deployment_id=str(self.deployment) if isinstance(self.deployment, UUID) else None,
@@ -97,8 +105,9 @@ class SubworkflowDeploymentNode(BaseSubworkflowNode[StateType], Generic[StateTyp
97
105
  external_id=self.external_id,
98
106
  event_types=["WORKFLOW"],
99
107
  metadata=self.metadata,
100
- request_options=self.request_options,
108
+ request_options=request_options,
101
109
  )
110
+ # for some reason execution context isn't showing as an option? ^ failing mypy
102
111
 
103
112
  outputs: Optional[List[WorkflowOutput]] = None
104
113
  fulfilled_output_names: Set[str] = set()
@@ -75,5 +75,5 @@ def test_text_prompt_deployment_node__basic(vellum_client):
75
75
  prompt_deployment_name="my-deployment",
76
76
  raw_overrides=OMIT,
77
77
  release_tag="LATEST",
78
- request_options=None,
78
+ request_options={"additional_body_parameters": {"execution_context": {"parent_context": None}}},
79
79
  )
@@ -7,6 +7,7 @@ from uuid import UUID
7
7
  from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Type, Union
8
8
 
9
9
  from vellum.workflows.constants import UNDEF
10
+ from vellum.workflows.context import execution_context, get_parent_context
10
11
  from vellum.workflows.descriptors.base import BaseDescriptor
11
12
  from vellum.workflows.edges.edge import Edge
12
13
  from vellum.workflows.errors import VellumError, VellumErrorCode
@@ -28,7 +29,7 @@ from vellum.workflows.events.node import (
28
29
  NodeExecutionRejectedBody,
29
30
  NodeExecutionStreamingBody,
30
31
  )
31
- from vellum.workflows.events.types import BaseEvent, ParentContext, WorkflowParentContext
32
+ from vellum.workflows.events.types import BaseEvent, NodeParentContext, ParentContext, WorkflowParentContext
32
33
  from vellum.workflows.events.workflow import (
33
34
  WorkflowExecutionFulfilledBody,
34
35
  WorkflowExecutionInitiatedBody,
@@ -125,7 +126,7 @@ class WorkflowRunner(Generic[StateType]):
125
126
 
126
127
  self._active_nodes_by_execution_id: Dict[UUID, BaseNode[StateType]] = {}
127
128
  self._cancel_signal = cancel_signal
128
- self._parent_context = parent_context
129
+ self._parent_context = get_parent_context() or parent_context
129
130
 
130
131
  setattr(
131
132
  self._initial_state,
@@ -156,6 +157,7 @@ class WorkflowRunner(Generic[StateType]):
156
157
  return event
157
158
 
158
159
  def _run_work_item(self, node: BaseNode[StateType], span_id: UUID) -> None:
160
+ parent_context = get_parent_context()
159
161
  self._workflow_event_inner_queue.put(
160
162
  NodeExecutionInitiatedEvent(
161
163
  trace_id=node.state.meta.trace_id,
@@ -164,19 +166,20 @@ class WorkflowRunner(Generic[StateType]):
164
166
  node_definition=node.__class__,
165
167
  inputs=node._inputs,
166
168
  ),
167
- parent=WorkflowParentContext(
168
- span_id=span_id,
169
- workflow_definition=self.workflow.__class__,
170
- parent=self._parent_context,
171
- type="WORKFLOW",
172
- ),
169
+ parent=parent_context,
173
170
  )
174
171
  )
175
172
 
176
173
  logger.debug(f"Started running node: {node.__class__.__name__}")
177
174
 
178
175
  try:
179
- node_run_response = node.run()
176
+ updated_parent_context = NodeParentContext(
177
+ span_id=span_id,
178
+ node_definition=node.__class__,
179
+ parent=parent_context,
180
+ )
181
+ with execution_context(parent_context=updated_parent_context):
182
+ node_run_response = node.run()
180
183
  ports = node.Ports()
181
184
  if not isinstance(node_run_response, (BaseOutputs, Iterator)):
182
185
  raise NodeException(
@@ -197,6 +200,7 @@ class WorkflowRunner(Generic[StateType]):
197
200
  outputs = node.Outputs()
198
201
 
199
202
  def initiate_node_streaming_output(output: BaseOutput) -> None:
203
+ parent_context = get_parent_context()
200
204
  streaming_output_queues[output.name] = Queue()
201
205
  output_descriptor = OutputReference(
202
206
  name=output.name,
@@ -216,60 +220,49 @@ class WorkflowRunner(Generic[StateType]):
216
220
  output=initiated_output,
217
221
  invoked_ports=initiated_ports,
218
222
  ),
219
- parent=WorkflowParentContext(
220
- span_id=span_id,
221
- workflow_definition=self.workflow.__class__,
222
- parent=self._parent_context,
223
- ),
223
+ parent=parent_context,
224
224
  ),
225
225
  )
226
226
 
227
- for output in node_run_response:
228
- invoked_ports = output > ports
229
- if output.is_initiated:
230
- initiate_node_streaming_output(output)
231
- elif output.is_streaming:
232
- if output.name not in streaming_output_queues:
227
+ with execution_context(parent_context=updated_parent_context):
228
+ for output in node_run_response:
229
+ invoked_ports = output > ports
230
+ if output.is_initiated:
233
231
  initiate_node_streaming_output(output)
234
-
235
- streaming_output_queues[output.name].put(output.delta)
236
- self._workflow_event_inner_queue.put(
237
- NodeExecutionStreamingEvent(
238
- trace_id=node.state.meta.trace_id,
239
- span_id=span_id,
240
- body=NodeExecutionStreamingBody(
241
- node_definition=node.__class__,
242
- output=output,
243
- invoked_ports=invoked_ports,
244
- ),
245
- parent=WorkflowParentContext(
232
+ elif output.is_streaming:
233
+ if output.name not in streaming_output_queues:
234
+ initiate_node_streaming_output(output)
235
+
236
+ streaming_output_queues[output.name].put(output.delta)
237
+ self._workflow_event_inner_queue.put(
238
+ NodeExecutionStreamingEvent(
239
+ trace_id=node.state.meta.trace_id,
246
240
  span_id=span_id,
247
- workflow_definition=self.workflow.__class__,
248
- parent=self._parent_context,
241
+ body=NodeExecutionStreamingBody(
242
+ node_definition=node.__class__,
243
+ output=output,
244
+ invoked_ports=invoked_ports,
245
+ ),
246
+ parent=parent_context,
249
247
  ),
250
- ),
251
- )
252
- elif output.is_fulfilled:
253
- if output.name in streaming_output_queues:
254
- streaming_output_queues[output.name].put(UNDEF)
255
-
256
- setattr(outputs, output.name, output.value)
257
- self._workflow_event_inner_queue.put(
258
- NodeExecutionStreamingEvent(
259
- trace_id=node.state.meta.trace_id,
260
- span_id=span_id,
261
- body=NodeExecutionStreamingBody(
262
- node_definition=node.__class__,
263
- output=output,
264
- invoked_ports=invoked_ports,
265
- ),
266
- parent=WorkflowParentContext(
248
+ )
249
+ elif output.is_fulfilled:
250
+ if output.name in streaming_output_queues:
251
+ streaming_output_queues[output.name].put(UNDEF)
252
+
253
+ setattr(outputs, output.name, output.value)
254
+ self._workflow_event_inner_queue.put(
255
+ NodeExecutionStreamingEvent(
256
+ trace_id=node.state.meta.trace_id,
267
257
  span_id=span_id,
268
- workflow_definition=self.workflow.__class__,
269
- parent=self._parent_context,
270
- ),
258
+ body=NodeExecutionStreamingBody(
259
+ node_definition=node.__class__,
260
+ output=output,
261
+ invoked_ports=invoked_ports,
262
+ ),
263
+ parent=parent_context,
264
+ )
271
265
  )
272
- )
273
266
 
274
267
  invoked_ports = ports(outputs, node.state)
275
268
  node.state.meta.node_execution_cache.fulfill_node_execution(node.__class__, span_id)
@@ -291,11 +284,7 @@ class WorkflowRunner(Generic[StateType]):
291
284
  outputs=outputs,
292
285
  invoked_ports=invoked_ports,
293
286
  ),
294
- parent=WorkflowParentContext(
295
- span_id=span_id,
296
- workflow_definition=self.workflow.__class__,
297
- parent=self._parent_context,
298
- ),
287
+ parent=parent_context,
299
288
  )
300
289
  )
301
290
  except NodeException as e:
@@ -328,16 +317,19 @@ class WorkflowRunner(Generic[StateType]):
328
317
  code=VellumErrorCode.INTERNAL_ERROR,
329
318
  ),
330
319
  ),
331
- parent=WorkflowParentContext(
332
- span_id=span_id,
333
- workflow_definition=self.workflow.__class__,
334
- parent=self._parent_context,
335
- ),
320
+ parent=parent_context,
336
321
  ),
337
322
  )
338
323
 
339
324
  logger.debug(f"Finished running node: {node.__class__.__name__}")
340
325
 
326
def _context_run_work_item(
    self,
    node: BaseNode[StateType],
    span_id: UUID,
    parent_context: Optional[ParentContext] = None,
) -> None:
    """Run ``node`` via ``_run_work_item`` inside an execution context.

    Used as the worker-thread target when a node is scheduled; the scheduler
    captures the parent context on its own thread and passes it in here.

    Args:
        node: The node instance to execute.
        span_id: The execution span id for this node run.
        parent_context: Parent context to bind for the duration of the run.
            When ``None``, falls back to ``get_parent_context()`` and then to
            the runner's own ``self._parent_context``.
    """
    if parent_context is None:
        # NOTE(review): presumably the ambient parent context is not inherited
        # by a freshly started thread, hence the explicit fallback chain.
        parent_context = get_parent_context() or self._parent_context

    with execution_context(parent_context=parent_context):
        self._run_work_item(node, span_id)
332
+
341
333
  def _handle_invoked_ports(self, state: StateType, ports: Optional[Iterable[Port]]) -> None:
342
334
  if not ports:
343
335
  return
@@ -372,13 +364,14 @@ class WorkflowRunner(Generic[StateType]):
372
364
  if not node_class.Trigger.should_initiate(state, all_deps, node_span_id):
373
365
  return
374
366
 
367
+ current_parent = get_parent_context()
375
368
  node = node_class(state=state, context=self.workflow.context)
376
369
  state.meta.node_execution_cache.initiate_node_execution(node_class, node_span_id)
377
370
  self._active_nodes_by_execution_id[node_span_id] = node
378
371
 
379
372
  worker_thread = Thread(
380
- target=self._run_work_item,
381
- kwargs={"node": node, "span_id": node_span_id},
373
+ target=self._context_run_work_item,
374
+ kwargs={"node": node, "span_id": node_span_id, "parent_context": current_parent},
382
375
  )
383
376
  worker_thread.start()
384
377
 
@@ -504,9 +497,16 @@ class WorkflowRunner(Generic[StateType]):
504
497
  for edge in self.workflow.get_edges():
505
498
  self._dependencies[edge.to_node].add(edge.from_port.node_class)
506
499
 
500
+ current_parent = WorkflowParentContext(
501
+ span_id=self._initial_state.meta.span_id,
502
+ workflow_definition=self.workflow.__class__,
503
+ parent=self._parent_context,
504
+ type="WORKFLOW",
505
+ )
507
506
  for node_cls in self._entrypoints:
508
507
  try:
509
- self._run_node_if_ready(self._initial_state, node_cls)
508
+ with execution_context(parent_context=current_parent):
509
+ self._run_node_if_ready(self._initial_state, node_cls)
510
510
  except NodeException as e:
511
511
  self._workflow_event_outer_queue.put(self._reject_workflow_event(e.error))
512
512
  return
@@ -530,7 +530,9 @@ class WorkflowRunner(Generic[StateType]):
530
530
 
531
531
  self._workflow_event_outer_queue.put(event)
532
532
 
533
- rejection_error = self._handle_work_item_event(event)
533
+ with execution_context(parent_context=current_parent):
534
+ rejection_error = self._handle_work_item_event(event)
535
+
534
536
  if rejection_error:
535
537
  break
536
538
 
@@ -539,7 +541,9 @@ class WorkflowRunner(Generic[StateType]):
539
541
  while event := self._workflow_event_inner_queue.get_nowait():
540
542
  self._workflow_event_outer_queue.put(event)
541
543
 
542
- rejection_error = self._handle_work_item_event(event)
544
+ with execution_context(parent_context=current_parent):
545
+ rejection_error = self._handle_work_item_event(event)
546
+
543
547
  if rejection_error:
544
548
  break
545
549
  except Empty:
@@ -46,7 +46,10 @@ def root_workflow_event_filter(workflow_definition: Type["BaseWorkflow"], event:
46
46
  if event.parent.type != "WORKFLOW":
47
47
  return False
48
48
 
49
- return event.parent.workflow_definition == CodeResourceDefinition.encode(workflow_definition)
49
+ event_parent_definition = event.parent.workflow_definition
50
+ current_workflow_definition = CodeResourceDefinition.encode(workflow_definition)
51
+
52
+ return event_parent_definition.model_dump() == current_workflow_definition.model_dump()
50
53
 
51
54
 
52
55
  def all_workflow_event_filter(workflow_definition: Type["BaseWorkflow"], event: "WorkflowEvent") -> bool:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vellum-ai
3
- Version: 0.11.0
3
+ Version: 0.11.3
4
4
  Summary:
5
5
  License: MIT
6
6
  Requires-Python: >=3.9,<4.0