vellum-ai 0.11.0__py3-none-any.whl → 0.11.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. vellum/__init__.py +18 -0
  2. vellum/client/core/client_wrapper.py +1 -1
  3. vellum/client/types/__init__.py +30 -0
  4. vellum/client/types/document_document_to_document_index.py +6 -0
  5. vellum/client/types/slim_document.py +2 -2
  6. vellum/client/types/slim_document_document_to_document_index.py +43 -0
  7. vellum/client/types/test_suite_run_exec_config.py +7 -1
  8. vellum/client/types/test_suite_run_exec_config_request.py +8 -0
  9. vellum/client/types/test_suite_run_prompt_sandbox_history_item_exec_config.py +31 -0
  10. vellum/client/types/test_suite_run_prompt_sandbox_history_item_exec_config_data.py +27 -0
  11. vellum/client/types/test_suite_run_prompt_sandbox_history_item_exec_config_data_request.py +27 -0
  12. vellum/client/types/test_suite_run_prompt_sandbox_history_item_exec_config_request.py +31 -0
  13. vellum/client/types/test_suite_run_workflow_sandbox_history_item_exec_config.py +31 -0
  14. vellum/client/types/test_suite_run_workflow_sandbox_history_item_exec_config_data.py +27 -0
  15. vellum/client/types/test_suite_run_workflow_sandbox_history_item_exec_config_data_request.py +27 -0
  16. vellum/client/types/test_suite_run_workflow_sandbox_history_item_exec_config_request.py +31 -0
  17. vellum/types/slim_document_document_to_document_index.py +3 -0
  18. vellum/types/test_suite_run_prompt_sandbox_history_item_exec_config.py +3 -0
  19. vellum/types/test_suite_run_prompt_sandbox_history_item_exec_config_data.py +3 -0
  20. vellum/types/test_suite_run_prompt_sandbox_history_item_exec_config_data_request.py +3 -0
  21. vellum/types/test_suite_run_prompt_sandbox_history_item_exec_config_request.py +3 -0
  22. vellum/types/test_suite_run_workflow_sandbox_history_item_exec_config.py +3 -0
  23. vellum/types/test_suite_run_workflow_sandbox_history_item_exec_config_data.py +3 -0
  24. vellum/types/test_suite_run_workflow_sandbox_history_item_exec_config_data_request.py +3 -0
  25. vellum/types/test_suite_run_workflow_sandbox_history_item_exec_config_request.py +3 -0
  26. vellum/workflows/context.py +42 -0
  27. vellum/workflows/descriptors/base.py +1 -1
  28. vellum/workflows/nodes/bases/tests/test_base_node.py +1 -1
  29. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +13 -7
  30. vellum/workflows/nodes/core/map_node/node.py +27 -4
  31. vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +21 -0
  32. vellum/workflows/nodes/displayable/api_node/node.py +3 -2
  33. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +9 -0
  34. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +10 -1
  35. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +10 -1
  36. vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +1 -1
  37. vellum/workflows/runner/runner.py +74 -70
  38. vellum/workflows/workflows/event_filters.py +4 -1
  39. {vellum_ai-0.11.0.dist-info → vellum_ai-0.11.3.dist-info}/METADATA +1 -1
  40. {vellum_ai-0.11.0.dist-info → vellum_ai-0.11.3.dist-info}/RECORD +54 -35
  41. vellum_cli/pull.py +3 -1
  42. vellum_cli/tests/test_pull.py +18 -0
  43. vellum_ee/workflows/display/base.py +1 -2
  44. vellum_ee/workflows/display/nodes/vellum/api_node.py +53 -54
  45. vellum_ee/workflows/display/nodes/vellum/merge_node.py +20 -1
  46. vellum_ee/workflows/display/nodes/vellum/templating_node.py +15 -4
  47. vellum_ee/workflows/display/nodes/vellum/utils.py +26 -6
  48. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +29 -1
  49. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +2 -2
  50. vellum_ee/workflows/display/vellum.py +4 -4
  51. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +18 -9
  52. {vellum_ai-0.11.0.dist-info → vellum_ai-0.11.3.dist-info}/LICENSE +0 -0
  53. {vellum_ai-0.11.0.dist-info → vellum_ai-0.11.3.dist-info}/WHEEL +0 -0
  54. {vellum_ai-0.11.0.dist-info → vellum_ai-0.11.3.dist-info}/entry_points.txt +0 -0
vellum/workflows/nodes/core/map_node/node.py
@@ -3,14 +3,18 @@ from queue import Empty, Queue
 from threading import Thread
 from typing import TYPE_CHECKING, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, overload
 
+from vellum.workflows.context import execution_context, get_parent_context
 from vellum.workflows.descriptors.base import BaseDescriptor
 from vellum.workflows.errors.types import VellumErrorCode
+from vellum.workflows.events.types import ParentContext
 from vellum.workflows.exceptions import NodeException
 from vellum.workflows.inputs.base import BaseInputs
 from vellum.workflows.nodes.bases import BaseNode
 from vellum.workflows.outputs import BaseOutputs
 from vellum.workflows.state.base import BaseState
+from vellum.workflows.state.context import WorkflowContext
 from vellum.workflows.types.generics import NodeType, StateType
+from vellum.workflows.workflows.event_filters import all_workflow_event_filter
 
 if TYPE_CHECKING:
     from vellum.workflows import BaseWorkflow
@@ -53,7 +57,15 @@ class MapNode(BaseNode, Generic[StateType, MapNodeItemType]):
         fulfilled_iterations: List[bool] = []
         for index, item in enumerate(self.items):
             fulfilled_iterations.append(False)
-            thread = Thread(target=self._run_subworkflow, kwargs={"item": item, "index": index})
+            parent_context = get_parent_context() or self._context.parent_context
+            thread = Thread(
+                target=self._context_run_subworkflow,
+                kwargs={
+                    "item": item,
+                    "index": index,
+                    "parent_context": parent_context,
+                },
+            )
             thread.start()
 
         try:
@@ -62,6 +74,7 @@ class MapNode(BaseNode, Generic[StateType, MapNodeItemType]):
             while map_node_event := self._event_queue.get():
                 index = map_node_event[0]
                 terminal_event = map_node_event[1]
+                self._context._emit_subworkflow_event(terminal_event)
 
                 if terminal_event.name == "workflow.execution.fulfilled":
                     workflow_output_vars = vars(terminal_event.outputs)
@@ -85,12 +98,22 @@
             )
         except Empty:
             pass
-
         return self.Outputs(**mapped_items)
 
+    def _context_run_subworkflow(
+        self, *, item: MapNodeItemType, index: int, parent_context: Optional[ParentContext] = None
+    ) -> None:
+        parent_context = parent_context or self._context.parent_context
+        with execution_context(parent_context=parent_context):
+            self._run_subworkflow(item=item, index=index)
+
     def _run_subworkflow(self, *, item: MapNodeItemType, index: int) -> None:
-        subworkflow = self.subworkflow(parent_state=self.state, context=self._context)
-        events = subworkflow.stream(inputs=self.SubworkflowInputs(index=index, item=item, all_items=self.items))
+        context = WorkflowContext(_vellum_client=self._context._vellum_client)
+        subworkflow = self.subworkflow(parent_state=self.state, context=context)
+        events = subworkflow.stream(
+            inputs=self.SubworkflowInputs(index=index, item=item, all_items=self.items),
+            event_filter=all_workflow_event_filter,
+        )
 
         for event in events:
            self._event_queue.put((index, event))
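
The map node change above hands the ambient parent context to each worker thread explicitly, because contextvars set on the calling thread are not inherited by threads started with threading.Thread. Below is a minimal, self-contained sketch of that pattern using only the standard library; vellum's own execution_context and get_parent_context (added in vellum/workflows/context.py in this release) are assumed to behave along these lines, but their implementation is not shown in this diff.

# Minimal sketch (not vellum's implementation): a contextvar holding the "parent
# context", a context manager to set it, and explicit capture/re-entry when handing
# work to a Thread, since threads do not inherit the caller's contextvars.
import threading
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Iterator, Optional

_parent_context: ContextVar[Optional[str]] = ContextVar("parent_context", default=None)

def get_parent_context() -> Optional[str]:
    return _parent_context.get()

@contextmanager
def execution_context(parent_context: Optional[str]) -> Iterator[None]:
    token = _parent_context.set(parent_context)
    try:
        yield
    finally:
        _parent_context.reset(token)

def run_item(item: int, parent_context: Optional[str]) -> None:
    # Re-enter the captured context inside the worker thread.
    with execution_context(parent_context=parent_context):
        print(item, get_parent_context())

with execution_context(parent_context="workflow-span-123"):
    captured = get_parent_context()  # captured on the calling thread, before start()
    thread = threading.Thread(target=run_item, kwargs={"item": 1, "parent_context": captured})
    thread.start()
    thread.join()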
vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py
@@ -1,5 +1,6 @@
 import json
 
+from vellum.workflows.nodes.bases.base import BaseNode
 from vellum.workflows.nodes.core.templating_node.node import TemplatingNode
 
 
@@ -19,3 +20,23 @@ def test_templating_node__dict_output():
 
     # THEN the output is json serializable
     assert json.loads(outputs.result) == {"key": "value"}
+
+
+def test_templating_node__execution_count_reference():
+    # GIVEN a random node
+    class OtherNode(BaseNode):
+        pass
+
+    # AND a templating node that references the execution count of the random node
+    class TemplateNode(TemplatingNode):
+        template = "{{ total }}"
+        inputs = {
+            "total": OtherNode.Execution.count,
+        }
+
+    # WHEN the node is run
+    node = TemplateNode()
+    outputs = node.run()
+
+    # THEN the output is just the total
+    assert outputs.result == "0"
vellum/workflows/nodes/displayable/api_node/node.py
@@ -18,13 +18,14 @@ class APINode(BaseAPINode):
 
     authorization_type: Optional[AuthorizationType] = None - The type of authorization to use for the API call.
     api_key_header_key: Optional[str] = None - The header key to use for the API key authorization.
-    bearer_token_value: Optional[str] = None - The bearer token value to use for the bearer token authorization.
+    bearer_token_value: Optional[Union[str, VellumSecretReference]] = None - The bearer token value to use
+    for the bearer token authorization.
     """
 
     authorization_type: Optional[AuthorizationType] = None
     api_key_header_key: Optional[str] = None
     api_key_header_value: Optional[Union[str, VellumSecretReference]] = None
-    bearer_token_value: Optional[str] = None
+    bearer_token_value: Optional[Union[str, VellumSecretReference]] = None
 
     def run(self) -> BaseAPINode.Outputs:
         headers = self.headers or {}
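
The APINode change widens bearer_token_value to accept either a literal string or a VellumSecretReference, matching api_key_header_value. The sketch below only illustrates the general shape of resolving such a union into a request header; SecretRef and resolve_secret are hypothetical stand-ins, not vellum APIs, and the real handling of VellumSecretReference is not shown in this diff.

# Illustrative only: resolving "literal string or secret reference" into a header.
from dataclasses import dataclass
from typing import Dict, Optional, Union

@dataclass
class SecretRef:
    name: str

def resolve_secret(ref: SecretRef) -> str:
    # Placeholder lookup; a real implementation would query a secrets store.
    return {"MY_BEARER_TOKEN": "s3cr3t"}[ref.name]

def build_auth_header(bearer_token_value: Optional[Union[str, SecretRef]]) -> Dict[str, str]:
    if bearer_token_value is None:
        return {}
    token = resolve_secret(bearer_token_value) if isinstance(bearer_token_value, SecretRef) else bearer_token_value
    return {"Authorization": f"Bearer {token}"}

print(build_auth_header("abc123"))
print(build_auth_header(SecretRef(name="MY_BEARER_TOKEN")))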
vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py
@@ -14,7 +14,9 @@ from vellum import (
     PromptRequestStringInput,
     VellumVariable,
 )
+from vellum.client import RequestOptions
 from vellum.workflows.constants import OMIT
+from vellum.workflows.context import get_parent_context
 from vellum.workflows.errors import VellumErrorCode
 from vellum.workflows.exceptions import NodeException
 from vellum.workflows.nodes.displayable.bases.base_prompt_node import BasePromptNode
@@ -48,6 +50,13 @@ class BaseInlinePromptNode(BasePromptNode, Generic[StateType]):
 
     def _get_prompt_event_stream(self) -> Iterator[AdHocExecutePromptEvent]:
         input_variables, input_values = self._compile_prompt_inputs()
+        current_parent_context = get_parent_context()
+        parent_context = current_parent_context.model_dump_json() if current_parent_context else None
+        request_options = self.request_options or RequestOptions()
+        request_options["additional_body_parameters"] = {
+            "execution_context": {"parent_context": parent_context},
+            **request_options.get("additional_body_parameters", {}),
+        }
 
         return self._context.vellum_client.ad_hoc.adhoc_execute_prompt_stream(
             ml_model=self.ml_model,
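
The inline prompt node now injects an execution_context entry into the request's additional_body_parameters. Because the injected key is written first and any caller-supplied additional_body_parameters are spread after it, a caller's own execution_context value wins the merge. A plain-dict illustration of that precedence:

# Later keys in a dict literal override earlier ones, so the spread caller params win.
existing = {"additional_body_parameters": {"execution_context": {"parent_context": "caller"}, "foo": "bar"}}

merged = {
    "execution_context": {"parent_context": "injected"},
    **existing.get("additional_body_parameters", {}),
}

print(merged)
# {'execution_context': {'parent_context': 'caller'}, 'foo': 'bar'}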
vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py
@@ -11,7 +11,9 @@ from vellum import (
     RawPromptExecutionOverridesRequest,
     StringInputRequest,
 )
+from vellum.client import RequestOptions
 from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT
+from vellum.workflows.context import get_parent_context
 from vellum.workflows.errors import VellumErrorCode
 from vellum.workflows.exceptions import NodeException
 from vellum.workflows.nodes.displayable.bases.base_prompt_node import BasePromptNode
@@ -46,6 +48,13 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
     metadata: Optional[Dict[str, Optional[Any]]] = OMIT
 
     def _get_prompt_event_stream(self) -> Iterator[ExecutePromptEvent]:
+        current_parent_context = get_parent_context()
+        parent_context = current_parent_context.model_dump() if current_parent_context else None
+        request_options = self.request_options or RequestOptions()
+        request_options["additional_body_parameters"] = {
+            "execution_context": {"parent_context": parent_context},
+            **request_options.get("additional_body_parameters", {}),
+        }
         return self._context.vellum_client.execute_prompt_stream(
             inputs=self._compile_prompt_inputs(),
             prompt_deployment_id=str(self.deployment) if isinstance(self.deployment, UUID) else None,
@@ -56,7 +65,7 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
             raw_overrides=self.raw_overrides,
             expand_raw=self.expand_raw,
             metadata=self.metadata,
-            request_options=self.request_options,
+            request_options=request_options,
         )
 
     def _compile_prompt_inputs(self) -> List[PromptDeploymentInputRequest]:
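
Note the small differences in how each call site serializes the parent context: the inline prompt node uses model_dump_json() (a JSON string), this node uses model_dump() (a Python dict that may still contain UUIDs), and the subworkflow deployment node below uses model_dump(mode="json") (a dict of JSON-safe primitives). A pydantic v2 illustration with a stand-in model, not the real ParentContext:

# Stand-in model used only to show the three serialization shapes.
from uuid import UUID, uuid4

from pydantic import BaseModel

class FakeParentContext(BaseModel):
    span_id: UUID
    type: str = "WORKFLOW"

ctx = FakeParentContext(span_id=uuid4())

print(type(ctx.model_dump()["span_id"]))             # <class 'uuid.UUID'>
print(type(ctx.model_dump(mode="json")["span_id"]))  # <class 'str'>
print(type(ctx.model_dump_json()))                   # <class 'str'> (whole payload as JSON text)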
vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py
@@ -13,6 +13,7 @@ from vellum import (
 )
 from vellum.core import RequestOptions
 from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT
+from vellum.workflows.context import get_parent_context
 from vellum.workflows.errors import VellumErrorCode
 from vellum.workflows.exceptions import NodeException
 from vellum.workflows.nodes.bases.base_subworkflow_node.node import BaseSubworkflowNode
@@ -89,6 +90,13 @@ class SubworkflowDeploymentNode(BaseSubworkflowNode[StateType], Generic[StateTyp
         return compiled_inputs
 
     def run(self) -> Iterator[BaseOutput]:
+        current_parent_context = get_parent_context()
+        parent_context = current_parent_context.model_dump(mode="json") if current_parent_context else None
+        request_options = self.request_options or RequestOptions()
+        request_options["additional_body_parameters"] = {
+            "execution_context": {"parent_context": parent_context},
+            **request_options.get("additional_body_parameters", {}),
+        }
         subworkflow_stream = self._context.vellum_client.execute_workflow_stream(
             inputs=self._compile_subworkflow_inputs(),
             workflow_deployment_id=str(self.deployment) if isinstance(self.deployment, UUID) else None,
@@ -97,8 +105,9 @@ class SubworkflowDeploymentNode(BaseSubworkflowNode[StateType], Generic[StateTyp
             external_id=self.external_id,
             event_types=["WORKFLOW"],
             metadata=self.metadata,
-            request_options=self.request_options,
+            request_options=request_options,
        )
+        # for some reason execution context isn't showing as an option? ^ failing mypy
 
        outputs: Optional[List[WorkflowOutput]] = None
        fulfilled_output_names: Set[str] = set()
vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py
@@ -75,5 +75,5 @@ def test_text_prompt_deployment_node__basic(vellum_client):
         prompt_deployment_name="my-deployment",
         raw_overrides=OMIT,
         release_tag="LATEST",
-        request_options=None,
+        request_options={"additional_body_parameters": {"execution_context": {"parent_context": None}}},
     )
vellum/workflows/runner/runner.py
@@ -7,6 +7,7 @@ from uuid import UUID
 from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Type, Union
 
 from vellum.workflows.constants import UNDEF
+from vellum.workflows.context import execution_context, get_parent_context
 from vellum.workflows.descriptors.base import BaseDescriptor
 from vellum.workflows.edges.edge import Edge
 from vellum.workflows.errors import VellumError, VellumErrorCode
@@ -28,7 +29,7 @@ from vellum.workflows.events.node import (
     NodeExecutionRejectedBody,
     NodeExecutionStreamingBody,
 )
-from vellum.workflows.events.types import BaseEvent, ParentContext, WorkflowParentContext
+from vellum.workflows.events.types import BaseEvent, NodeParentContext, ParentContext, WorkflowParentContext
 from vellum.workflows.events.workflow import (
     WorkflowExecutionFulfilledBody,
     WorkflowExecutionInitiatedBody,
@@ -125,7 +126,7 @@ class WorkflowRunner(Generic[StateType]):
 
         self._active_nodes_by_execution_id: Dict[UUID, BaseNode[StateType]] = {}
         self._cancel_signal = cancel_signal
-        self._parent_context = parent_context
+        self._parent_context = get_parent_context() or parent_context
 
         setattr(
             self._initial_state,
@@ -156,6 +157,7 @@
         return event
 
     def _run_work_item(self, node: BaseNode[StateType], span_id: UUID) -> None:
+        parent_context = get_parent_context()
         self._workflow_event_inner_queue.put(
             NodeExecutionInitiatedEvent(
                 trace_id=node.state.meta.trace_id,
@@ -164,19 +166,20 @@
                     node_definition=node.__class__,
                     inputs=node._inputs,
                 ),
-                parent=WorkflowParentContext(
-                    span_id=span_id,
-                    workflow_definition=self.workflow.__class__,
-                    parent=self._parent_context,
-                    type="WORKFLOW",
-                ),
+                parent=parent_context,
             )
         )
 
         logger.debug(f"Started running node: {node.__class__.__name__}")
 
         try:
-            node_run_response = node.run()
+            updated_parent_context = NodeParentContext(
+                span_id=span_id,
+                node_definition=node.__class__,
+                parent=parent_context,
+            )
+            with execution_context(parent_context=updated_parent_context):
+                node_run_response = node.run()
             ports = node.Ports()
             if not isinstance(node_run_response, (BaseOutputs, Iterator)):
                 raise NodeException(
@@ -197,6 +200,7 @@
                 outputs = node.Outputs()
 
                 def initiate_node_streaming_output(output: BaseOutput) -> None:
+                    parent_context = get_parent_context()
                     streaming_output_queues[output.name] = Queue()
                     output_descriptor = OutputReference(
                         name=output.name,
@@ -216,60 +220,49 @@
                                output=initiated_output,
                                invoked_ports=initiated_ports,
                            ),
-                            parent=WorkflowParentContext(
-                                span_id=span_id,
-                                workflow_definition=self.workflow.__class__,
-                                parent=self._parent_context,
-                            ),
+                            parent=parent_context,
                        ),
                    )
 
-                for output in node_run_response:
-                    invoked_ports = output > ports
-                    if output.is_initiated:
-                        initiate_node_streaming_output(output)
-                    elif output.is_streaming:
-                        if output.name not in streaming_output_queues:
-                            initiate_node_streaming_output(output)
-
-                        streaming_output_queues[output.name].put(output.delta)
-                        self._workflow_event_inner_queue.put(
-                            NodeExecutionStreamingEvent(
-                                trace_id=node.state.meta.trace_id,
-                                span_id=span_id,
-                                body=NodeExecutionStreamingBody(
-                                    node_definition=node.__class__,
-                                    output=output,
-                                    invoked_ports=invoked_ports,
-                                ),
-                                parent=WorkflowParentContext(
-                                    span_id=span_id,
-                                    workflow_definition=self.workflow.__class__,
-                                    parent=self._parent_context,
-                                ),
-                            ),
-                        )
-                    elif output.is_fulfilled:
-                        if output.name in streaming_output_queues:
-                            streaming_output_queues[output.name].put(UNDEF)
-
-                        setattr(outputs, output.name, output.value)
-                        self._workflow_event_inner_queue.put(
-                            NodeExecutionStreamingEvent(
-                                trace_id=node.state.meta.trace_id,
-                                span_id=span_id,
-                                body=NodeExecutionStreamingBody(
-                                    node_definition=node.__class__,
-                                    output=output,
-                                    invoked_ports=invoked_ports,
-                                ),
-                                parent=WorkflowParentContext(
-                                    span_id=span_id,
-                                    workflow_definition=self.workflow.__class__,
-                                    parent=self._parent_context,
-                                ),
-                            ),
-                        )
+                with execution_context(parent_context=updated_parent_context):
+                    for output in node_run_response:
+                        invoked_ports = output > ports
+                        if output.is_initiated:
+                            initiate_node_streaming_output(output)
+                        elif output.is_streaming:
+                            if output.name not in streaming_output_queues:
+                                initiate_node_streaming_output(output)
+
+                            streaming_output_queues[output.name].put(output.delta)
+                            self._workflow_event_inner_queue.put(
+                                NodeExecutionStreamingEvent(
+                                    trace_id=node.state.meta.trace_id,
+                                    span_id=span_id,
+                                    body=NodeExecutionStreamingBody(
+                                        node_definition=node.__class__,
+                                        output=output,
+                                        invoked_ports=invoked_ports,
+                                    ),
+                                    parent=parent_context,
+                                )
+                            )
+                        elif output.is_fulfilled:
+                            if output.name in streaming_output_queues:
+                                streaming_output_queues[output.name].put(UNDEF)
+
+                            setattr(outputs, output.name, output.value)
+                            self._workflow_event_inner_queue.put(
+                                NodeExecutionStreamingEvent(
+                                    trace_id=node.state.meta.trace_id,
+                                    span_id=span_id,
+                                    body=NodeExecutionStreamingBody(
+                                        node_definition=node.__class__,
+                                        output=output,
+                                        invoked_ports=invoked_ports,
+                                    ),
+                                    parent=parent_context,
+                                )
+                            )
 
            invoked_ports = ports(outputs, node.state)
            node.state.meta.node_execution_cache.fulfill_node_execution(node.__class__, span_id)
@@ -291,11 +284,7 @@
                        outputs=outputs,
                        invoked_ports=invoked_ports,
                    ),
-                    parent=WorkflowParentContext(
-                        span_id=span_id,
-                        workflow_definition=self.workflow.__class__,
-                        parent=self._parent_context,
-                    ),
+                    parent=parent_context,
                )
            )
        except NodeException as e:
@@ -328,16 +317,19 @@
                            code=VellumErrorCode.INTERNAL_ERROR,
                        ),
                    ),
-                    parent=WorkflowParentContext(
-                        span_id=span_id,
-                        workflow_definition=self.workflow.__class__,
-                        parent=self._parent_context,
-                    ),
+                    parent=parent_context,
                ),
            )
 
        logger.debug(f"Finished running node: {node.__class__.__name__}")
 
+    def _context_run_work_item(self, node: BaseNode[StateType], span_id: UUID, parent_context=None) -> None:
+        if parent_context is None:
+            parent_context = get_parent_context() or self._parent_context
+
+        with execution_context(parent_context=parent_context):
+            self._run_work_item(node, span_id)
+
    def _handle_invoked_ports(self, state: StateType, ports: Optional[Iterable[Port]]) -> None:
        if not ports:
            return
372
364
  if not node_class.Trigger.should_initiate(state, all_deps, node_span_id):
373
365
  return
374
366
 
367
+ current_parent = get_parent_context()
375
368
  node = node_class(state=state, context=self.workflow.context)
376
369
  state.meta.node_execution_cache.initiate_node_execution(node_class, node_span_id)
377
370
  self._active_nodes_by_execution_id[node_span_id] = node
378
371
 
379
372
  worker_thread = Thread(
380
- target=self._run_work_item,
381
- kwargs={"node": node, "span_id": node_span_id},
373
+ target=self._context_run_work_item,
374
+ kwargs={"node": node, "span_id": node_span_id, "parent_context": current_parent},
382
375
  )
383
376
  worker_thread.start()
384
377
 
@@ -504,9 +497,16 @@
        for edge in self.workflow.get_edges():
            self._dependencies[edge.to_node].add(edge.from_port.node_class)
 
+        current_parent = WorkflowParentContext(
+            span_id=self._initial_state.meta.span_id,
+            workflow_definition=self.workflow.__class__,
+            parent=self._parent_context,
+            type="WORKFLOW",
+        )
        for node_cls in self._entrypoints:
            try:
-                self._run_node_if_ready(self._initial_state, node_cls)
+                with execution_context(parent_context=current_parent):
+                    self._run_node_if_ready(self._initial_state, node_cls)
            except NodeException as e:
                self._workflow_event_outer_queue.put(self._reject_workflow_event(e.error))
                return
@@ -530,7 +530,9 @@
 
                self._workflow_event_outer_queue.put(event)
 
-                rejection_error = self._handle_work_item_event(event)
+                with execution_context(parent_context=current_parent):
+                    rejection_error = self._handle_work_item_event(event)
+
                if rejection_error:
                    break
 
@@ -539,7 +541,9 @@
                while event := self._workflow_event_inner_queue.get_nowait():
                    self._workflow_event_outer_queue.put(event)
 
-                    rejection_error = self._handle_work_item_event(event)
+                    with execution_context(parent_context=current_parent):
+                        rejection_error = self._handle_work_item_event(event)
+
                    if rejection_error:
                        break
        except Empty:
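
Taken together, the runner changes replace per-emit construction of WorkflowParentContext with an ambient context: stream() enters a WORKFLOW-level parent once, each node execution layers a NODE-level parent on top, and emit sites simply read the current context. Below is a self-contained sketch of that nesting; the Parent dataclass and emit helper are illustrative stand-ins, not vellum's event types.

# Sketch of the nesting: any event emitted inside node.run() can read its full
# ancestry from the ambient context instead of rebuilding it at every emit site.
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from typing import Iterator, Optional

@dataclass
class Parent:
    type: str  # "WORKFLOW" or "NODE"
    name: str
    parent: Optional["Parent"] = None

_current: ContextVar[Optional[Parent]] = ContextVar("current", default=None)

@contextmanager
def execution_context(parent: Parent) -> Iterator[None]:
    token = _current.set(parent)
    try:
        yield
    finally:
        _current.reset(token)

def emit(event_name: str) -> None:
    # Walk the ancestry chain from the ambient context.
    chain = []
    node = _current.get()
    while node:
        chain.append(f"{node.type}:{node.name}")
        node = node.parent
    print(event_name, "->", " / ".join(chain))

workflow_parent = Parent(type="WORKFLOW", name="MyWorkflow")
with execution_context(workflow_parent):
    node_parent = Parent(type="NODE", name="MyNode", parent=workflow_parent)
    with execution_context(node_parent):
        emit("node.execution.initiated")
# node.execution.initiated -> NODE:MyNode / WORKFLOW:MyWorkflow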
@@ -46,7 +46,10 @@ def root_workflow_event_filter(workflow_definition: Type["BaseWorkflow"], event:
46
46
  if event.parent.type != "WORKFLOW":
47
47
  return False
48
48
 
49
- return event.parent.workflow_definition == CodeResourceDefinition.encode(workflow_definition)
49
+ event_parent_definition = event.parent.workflow_definition
50
+ current_workflow_definition = CodeResourceDefinition.encode(workflow_definition)
51
+
52
+ return event_parent_definition.model_dump() == current_workflow_definition.model_dump()
50
53
 
51
54
 
52
55
  def all_workflow_event_filter(workflow_definition: Type["BaseWorkflow"], event: "WorkflowEvent") -> bool:
{vellum_ai-0.11.0.dist-info → vellum_ai-0.11.3.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: vellum-ai
-Version: 0.11.0
+Version: 0.11.3
 Summary:
 License: MIT
 Requires-Python: >=3.9,<4.0