vellum-ai: vellum_ai-0.14.5-py3-none-any.whl → vellum_ai-0.14.7-py3-none-any.whl

This diff shows the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registry.
Files changed (68):
  1. vellum/__init__.py +18 -0
  2. vellum/client/__init__.py +8 -8
  3. vellum/client/core/client_wrapper.py +1 -1
  4. vellum/client/resources/__init__.py +2 -0
  5. vellum/client/resources/workflow_sandboxes/__init__.py +3 -0
  6. vellum/client/resources/workflow_sandboxes/client.py +146 -0
  7. vellum/client/resources/workflow_sandboxes/types/__init__.py +5 -0
  8. vellum/client/resources/workflow_sandboxes/types/list_workflow_sandbox_examples_request_tag.py +5 -0
  9. vellum/client/types/__init__.py +16 -0
  10. vellum/client/types/array_chat_message_content_item.py +6 -1
  11. vellum/client/types/array_chat_message_content_item_request.py +2 -0
  12. vellum/client/types/chat_message_content.py +2 -0
  13. vellum/client/types/chat_message_content_request.py +2 -0
  14. vellum/client/types/document_chat_message_content.py +25 -0
  15. vellum/client/types/document_chat_message_content_request.py +25 -0
  16. vellum/client/types/document_vellum_value.py +25 -0
  17. vellum/client/types/document_vellum_value_request.py +25 -0
  18. vellum/client/types/paginated_workflow_sandbox_example_list.py +23 -0
  19. vellum/client/types/vellum_document.py +20 -0
  20. vellum/client/types/vellum_document_request.py +20 -0
  21. vellum/client/types/vellum_value.py +2 -0
  22. vellum/client/types/vellum_value_request.py +2 -0
  23. vellum/client/types/vellum_variable_type.py +1 -0
  24. vellum/client/types/workflow_sandbox_example.py +22 -0
  25. vellum/resources/workflow_sandboxes/types/__init__.py +3 -0
  26. vellum/resources/workflow_sandboxes/types/list_workflow_sandbox_examples_request_tag.py +3 -0
  27. vellum/types/document_chat_message_content.py +3 -0
  28. vellum/types/document_chat_message_content_request.py +3 -0
  29. vellum/types/document_vellum_value.py +3 -0
  30. vellum/types/document_vellum_value_request.py +3 -0
  31. vellum/types/paginated_workflow_sandbox_example_list.py +3 -0
  32. vellum/types/vellum_document.py +3 -0
  33. vellum/types/vellum_document_request.py +3 -0
  34. vellum/types/workflow_sandbox_example.py +3 -0
  35. vellum/workflows/exceptions.py +18 -0
  36. vellum/workflows/inputs/base.py +27 -1
  37. vellum/workflows/inputs/tests/__init__.py +0 -0
  38. vellum/workflows/inputs/tests/test_inputs.py +49 -0
  39. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +1 -1
  40. vellum/workflows/nodes/core/map_node/node.py +7 -7
  41. vellum/workflows/nodes/core/try_node/node.py +1 -1
  42. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +2 -2
  43. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +5 -3
  44. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +5 -4
  45. vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +4 -4
  46. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +49 -15
  47. vellum/workflows/nodes/displayable/subworkflow_deployment_node/tests/test_node.py +165 -0
  48. vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +3 -1
  49. vellum/workflows/outputs/base.py +1 -1
  50. vellum/workflows/runner/runner.py +16 -10
  51. vellum/workflows/state/context.py +7 -7
  52. vellum/workflows/workflows/base.py +61 -59
  53. vellum/workflows/workflows/tests/test_base_workflow.py +131 -40
  54. {vellum_ai-0.14.5.dist-info → vellum_ai-0.14.7.dist-info}/METADATA +1 -1
  55. {vellum_ai-0.14.5.dist-info → vellum_ai-0.14.7.dist-info}/RECORD +68 -44
  56. vellum_cli/__init__.py +36 -0
  57. vellum_cli/init.py +128 -0
  58. vellum_cli/pull.py +6 -3
  59. vellum_cli/tests/test_init.py +355 -0
  60. vellum_cli/tests/test_pull.py +127 -0
  61. vellum_ee/workflows/display/nodes/base_node_display.py +4 -4
  62. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +31 -0
  63. vellum_ee/workflows/display/nodes/vellum/utils.py +8 -0
  64. vellum_ee/workflows/display/vellum.py +0 -4
  65. vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +29 -0
  66. {vellum_ai-0.14.5.dist-info → vellum_ai-0.14.7.dist-info}/LICENSE +0 -0
  67. {vellum_ai-0.14.5.dist-info → vellum_ai-0.14.7.dist-info}/WHEEL +0 -0
  68. {vellum_ai-0.14.5.dist-info → vellum_ai-0.14.7.dist-info}/entry_points.txt +0 -0
@@ -3,6 +3,7 @@ from datetime import datetime
3
3
  from uuid import uuid4
4
4
  from typing import Any, Iterator, List
5
5
 
6
+ from vellum.client.core.api_error import ApiError
6
7
  from vellum.client.types.chat_message import ChatMessage
7
8
  from vellum.client.types.chat_message_request import ChatMessageRequest
8
9
  from vellum.client.types.workflow_execution_workflow_result_event import WorkflowExecutionWorkflowResultEvent
@@ -11,6 +12,8 @@ from vellum.client.types.workflow_request_chat_history_input_request import Work
11
12
  from vellum.client.types.workflow_request_json_input_request import WorkflowRequestJsonInputRequest
12
13
  from vellum.client.types.workflow_result_event import WorkflowResultEvent
13
14
  from vellum.client.types.workflow_stream_event import WorkflowStreamEvent
15
+ from vellum.workflows.errors import WorkflowErrorCode
16
+ from vellum.workflows.exceptions import NodeException
14
17
  from vellum.workflows.nodes.displayable.subworkflow_deployment_node.node import SubworkflowDeploymentNode
15
18
 
16
19
 
@@ -129,3 +132,165 @@ def test_run_workflow__any_array(vellum_client):
129
132
  assert call_kwargs["inputs"] == [
130
133
  WorkflowRequestJsonInputRequest(name="fruits", value=["apple", "banana", "cherry"]),
131
134
  ]
135
+
136
+
137
+ def test_run_workflow__no_deployment():
138
+ """Confirm that we raise error when running a subworkflow deployment node with no deployment attribute set"""
139
+
140
+ # GIVEN a Subworkflow Deployment Node
141
+ class ExampleSubworkflowDeploymentNode(SubworkflowDeploymentNode):
142
+ subworkflow_inputs = {
143
+ "fruits": ["apple", "banana", "cherry"],
144
+ }
145
+
146
+ # WHEN/THEN running the node should raise a NodeException
147
+ node = ExampleSubworkflowDeploymentNode()
148
+ with pytest.raises(NodeException) as exc_info:
149
+ list(node.run())
150
+
151
+ # AND the error message should be correct
152
+ assert exc_info.value.code == WorkflowErrorCode.NODE_EXECUTION
153
+ assert "Expected subworkflow deployment attribute to be either a UUID or STR, got None instead" in str(
154
+ exc_info.value
155
+ )
156
+
157
+
158
+ def test_run_workflow__hyphenated_output(vellum_client):
159
+ """Confirm that we can successfully handle subworkflow outputs with hyphenated names"""
160
+
161
+ # GIVEN a Subworkflow Deployment Node
162
+ class ExampleSubworkflowDeploymentNode(SubworkflowDeploymentNode):
163
+ deployment = "example_subworkflow_deployment"
164
+ subworkflow_inputs = {
165
+ "test_input": "test_value",
166
+ }
167
+
168
+ class Outputs(SubworkflowDeploymentNode.Outputs):
169
+ final_output_copy: str
170
+
171
+ # AND we know what the Subworkflow Deployment will respond with
172
+ def generate_subworkflow_events(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
173
+ execution_id = str(uuid4())
174
+ expected_events: List[WorkflowStreamEvent] = [
175
+ WorkflowExecutionWorkflowResultEvent(
176
+ execution_id=execution_id,
177
+ data=WorkflowResultEvent(
178
+ id=str(uuid4()),
179
+ state="INITIATED",
180
+ ts=datetime.now(),
181
+ ),
182
+ ),
183
+ WorkflowExecutionWorkflowResultEvent(
184
+ execution_id=execution_id,
185
+ data=WorkflowResultEvent(
186
+ id=str(uuid4()),
187
+ state="FULFILLED",
188
+ ts=datetime.now(),
189
+ outputs=[
190
+ WorkflowOutputString(
191
+ id=str(uuid4()),
192
+ name="final-output_copy", # Note the hyphen here
193
+ value="test success",
194
+ )
195
+ ],
196
+ ),
197
+ ),
198
+ ]
199
+ yield from expected_events
200
+
201
+ vellum_client.execute_workflow_stream.side_effect = generate_subworkflow_events
202
+
203
+ # WHEN we run the node
204
+ node = ExampleSubworkflowDeploymentNode()
205
+ events = list(node.run())
206
+
207
+ # THEN the node should have completed successfully
208
+ assert events[-1].name == "final_output_copy" # Note the underscore here
209
+ assert events[-1].value == "test success"
210
+
211
+
212
+ @pytest.mark.parametrize(
213
+ ["exception", "expected_code", "expected_message"],
214
+ [
215
+ (
216
+ ApiError(status_code=400, body={"detail": "Missing required input variable: 'foo'"}),
217
+ WorkflowErrorCode.INVALID_INPUTS,
218
+ "Missing required input variable: 'foo'",
219
+ ),
220
+ (
221
+ ApiError(status_code=400, body={"message": "Missing required input variable: 'foo'"}),
222
+ WorkflowErrorCode.INVALID_INPUTS,
223
+ "Failed to execute Subworkflow Deployment",
224
+ ),
225
+ (
226
+ ApiError(status_code=400, body="Missing required input variable: 'foo'"),
227
+ WorkflowErrorCode.INTERNAL_ERROR,
228
+ "Failed to execute Subworkflow Deployment",
229
+ ),
230
+ (
231
+ ApiError(status_code=None, body={"detail": "Missing required input variable: 'foo'"}),
232
+ WorkflowErrorCode.INTERNAL_ERROR,
233
+ "Failed to execute Subworkflow Deployment",
234
+ ),
235
+ (
236
+ ApiError(status_code=500, body={"detail": "Missing required input variable: 'foo'"}),
237
+ WorkflowErrorCode.INTERNAL_ERROR,
238
+ "Failed to execute Subworkflow Deployment",
239
+ ),
240
+ ],
241
+ ids=["400", "invalid_dict", "invalid_body", "no_status_code", "500"],
242
+ )
243
+ def test_subworkflow_deployment_node__api_error__invalid_inputs_node_exception(
244
+ vellum_client, exception, expected_code, expected_message
245
+ ):
246
+ # GIVEN a prompt node with an invalid model name
247
+ class MyNode(SubworkflowDeploymentNode):
248
+ deployment = "example_subworkflow_deployment"
249
+ subworkflow_inputs = {
250
+ "not_foo": "bar",
251
+ }
252
+
253
+ # AND the Subworkflow Deployment API call fails
254
+ def _side_effect(*args: Any, **kwargs: Any) -> Iterator[WorkflowStreamEvent]:
255
+ if kwargs.get("_mock_condition_to_induce_an_error"):
256
+ yield WorkflowExecutionWorkflowResultEvent(
257
+ execution_id=str(uuid4()),
258
+ data=WorkflowResultEvent(
259
+ id=str(uuid4()),
260
+ state="INITIATED",
261
+ ts=datetime.now(),
262
+ ),
263
+ )
264
+ else:
265
+ raise exception
266
+
267
+ # AND the vellum client execute workflow stream raises a 4xx error
268
+ vellum_client.execute_workflow_stream.side_effect = _side_effect
269
+
270
+ # WHEN the node is run
271
+ with pytest.raises(NodeException) as e:
272
+ list(MyNode().run())
273
+
274
+ # THEN the node raises the correct NodeException
275
+ assert e.value.code == expected_code
276
+ assert e.value.message == expected_message
277
+
278
+
279
+ def test_subworkflow_deployment_node__immediate_api_error__node_exception(vellum_client):
280
+ # GIVEN a prompt node with an invalid model name
281
+ class MyNode(SubworkflowDeploymentNode):
282
+ deployment = "example_subworkflow_deployment"
283
+ subworkflow_inputs = {
284
+ "not_foo": "bar",
285
+ }
286
+
287
+ # AND the vellum client execute workflow stream raises a 4xx error
288
+ vellum_client.execute_workflow_stream.side_effect = ApiError(status_code=404, body={"detail": "Not found"})
289
+
290
+ # WHEN the node is run
291
+ with pytest.raises(NodeException) as e:
292
+ list(MyNode().run())
293
+
294
+ # THEN the node raises the correct NodeException
295
+ assert e.value.code == WorkflowErrorCode.INVALID_INPUTS
296
+ assert e.value.message == "Not found"
@@ -74,5 +74,7 @@ def test_text_prompt_deployment_node__basic(vellum_client):
74
74
  prompt_deployment_name="my-deployment",
75
75
  raw_overrides=OMIT,
76
76
  release_tag="LATEST",
77
- request_options={"additional_body_parameters": {"execution_context": {"parent_context": None}}},
77
+ request_options={
78
+ "additional_body_parameters": {"execution_context": {"parent_context": None, "trace_id": None}}
79
+ },
78
80
  )
@@ -32,7 +32,7 @@ class BaseOutput(Generic[_Delta, _Accumulated]):
32
32
  if value is not undefined and delta is not undefined:
33
33
  raise ValueError("Cannot set both value and delta")
34
34
 
35
- self._name = name
35
+ self._name = name.replace("-", "_") # Convert hyphens to underscores for valid python variable names
36
36
  self._value = value
37
37
  self._delta = delta
38
38
 
@@ -7,7 +7,7 @@ from uuid import UUID
7
7
  from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Tuple, Type, Union
8
8
 
9
9
  from vellum.workflows.constants import undefined
10
- from vellum.workflows.context import execution_context, get_parent_context
10
+ from vellum.workflows.context import ExecutionContext, execution_context, get_execution_context, get_parent_context
11
11
  from vellum.workflows.descriptors.base import BaseDescriptor
12
12
  from vellum.workflows.edges.edge import Edge
13
13
  from vellum.workflows.errors import WorkflowError, WorkflowErrorCode
@@ -29,7 +29,7 @@ from vellum.workflows.events.node import (
29
29
  NodeExecutionRejectedBody,
30
30
  NodeExecutionStreamingBody,
31
31
  )
32
- from vellum.workflows.events.types import BaseEvent, NodeParentContext, ParentContext, WorkflowParentContext
32
+ from vellum.workflows.events.types import BaseEvent, NodeParentContext, WorkflowParentContext
33
33
  from vellum.workflows.events.workflow import (
34
34
  WorkflowExecutionFulfilledBody,
35
35
  WorkflowExecutionInitiatedBody,
@@ -75,8 +75,8 @@ class WorkflowRunner(Generic[StateType]):
75
75
  external_inputs: Optional[ExternalInputsArg] = None,
76
76
  cancel_signal: Optional[ThreadingEvent] = None,
77
77
  node_output_mocks: Optional[MockNodeExecutionArg] = None,
78
- parent_context: Optional[ParentContext] = None,
79
78
  max_concurrency: Optional[int] = None,
79
+ init_execution_context: Optional[ExecutionContext] = None,
80
80
  ):
81
81
  if state and external_inputs:
82
82
  raise ValueError("Can only run a Workflow providing one of state or external inputs, not both")
@@ -98,6 +98,11 @@ class WorkflowRunner(Generic[StateType]):
98
98
  elif external_inputs:
99
99
  self._initial_state = self.workflow.get_most_recent_state()
100
100
  for descriptor, value in external_inputs.items():
101
+ if not any(isinstance(value, type_) for type_ in descriptor.types):
102
+ raise NodeException(
103
+ f"Invalid external input type for {descriptor.name}",
104
+ code=WorkflowErrorCode.INVALID_INPUTS,
105
+ )
101
106
  self._initial_state.meta.external_inputs[descriptor] = value
102
107
 
103
108
  self._entrypoints = [
@@ -133,7 +138,8 @@ class WorkflowRunner(Generic[StateType]):
133
138
 
134
139
  self._active_nodes_by_execution_id: Dict[UUID, BaseNode[StateType]] = {}
135
140
  self._cancel_signal = cancel_signal
136
- self._parent_context = get_parent_context() or parent_context
141
+ self._execution_context = init_execution_context or get_execution_context()
142
+ self._parent_context = self._execution_context.parent_context
137
143
 
138
144
  setattr(
139
145
  self._initial_state,
@@ -196,7 +202,7 @@ class WorkflowRunner(Generic[StateType]):
196
202
  break
197
203
 
198
204
  if not was_mocked:
199
- with execution_context(parent_context=updated_parent_context):
205
+ with execution_context(parent_context=updated_parent_context, trace_id=node.state.meta.trace_id):
200
206
  node_run_response = node.run()
201
207
 
202
208
  ports = node.Ports()
@@ -243,7 +249,7 @@ class WorkflowRunner(Generic[StateType]):
243
249
  ),
244
250
  )
245
251
 
246
- with execution_context(parent_context=updated_parent_context):
252
+ with execution_context(parent_context=updated_parent_context, trace_id=node.state.meta.trace_id):
247
253
  for output in node_run_response:
248
254
  invoked_ports = output > ports
249
255
  if output.is_initiated:
@@ -346,7 +352,7 @@ class WorkflowRunner(Generic[StateType]):
346
352
  if parent_context is None:
347
353
  parent_context = get_parent_context() or self._parent_context
348
354
 
349
- with execution_context(parent_context=parent_context):
355
+ with execution_context(parent_context=parent_context, trace_id=node.state.meta.trace_id):
350
356
  self._run_work_item(node, span_id)
351
357
 
352
358
  def _handle_invoked_ports(self, state: StateType, ports: Optional[Iterable[Port]]) -> None:
@@ -524,7 +530,7 @@ class WorkflowRunner(Generic[StateType]):
524
530
  for node_cls in self._entrypoints:
525
531
  try:
526
532
  if not self._max_concurrency or len(self._active_nodes_by_execution_id) < self._max_concurrency:
527
- with execution_context(parent_context=current_parent):
533
+ with execution_context(parent_context=current_parent, trace_id=self._initial_state.meta.trace_id):
528
534
  self._run_node_if_ready(self._initial_state, node_cls)
529
535
  else:
530
536
  self._concurrency_queue.put((self._initial_state, node_cls, None))
@@ -551,7 +557,7 @@ class WorkflowRunner(Generic[StateType]):
551
557
 
552
558
  self._workflow_event_outer_queue.put(event)
553
559
 
554
- with execution_context(parent_context=current_parent):
560
+ with execution_context(parent_context=current_parent, trace_id=self._initial_state.meta.trace_id):
555
561
  rejection_error = self._handle_work_item_event(event)
556
562
 
557
563
  if rejection_error:
@@ -562,7 +568,7 @@ class WorkflowRunner(Generic[StateType]):
562
568
  while event := self._workflow_event_inner_queue.get_nowait():
563
569
  self._workflow_event_outer_queue.put(event)
564
570
 
565
- with execution_context(parent_context=current_parent):
571
+ with execution_context(parent_context=current_parent, trace_id=self._initial_state.meta.trace_id):
566
572
  rejection_error = self._handle_work_item_event(event)
567
573
 
568
574
  if rejection_error:
@@ -3,7 +3,7 @@ from queue import Queue
3
3
  from typing import TYPE_CHECKING, Dict, List, Optional, Type
4
4
 
5
5
  from vellum import Vellum
6
- from vellum.workflows.events.types import ParentContext
6
+ from vellum.workflows.context import ExecutionContext, get_execution_context
7
7
  from vellum.workflows.nodes.mocks import MockNodeExecution, MockNodeExecutionArg
8
8
  from vellum.workflows.outputs.base import BaseOutputs
9
9
  from vellum.workflows.references.constant import ConstantValueReference
@@ -18,12 +18,14 @@ class WorkflowContext:
18
18
  self,
19
19
  *,
20
20
  vellum_client: Optional[Vellum] = None,
21
- parent_context: Optional[ParentContext] = None,
21
+ execution_context: Optional[ExecutionContext] = None,
22
22
  ):
23
23
  self._vellum_client = vellum_client
24
- self._parent_context = parent_context
25
24
  self._event_queue: Optional[Queue["WorkflowEvent"]] = None
26
25
  self._node_output_mocks_map: Dict[Type[BaseOutputs], List[MockNodeExecution]] = {}
26
+ self._execution_context = get_execution_context()
27
+ if not self._execution_context.parent_context and execution_context:
28
+ self._execution_context = execution_context
27
29
 
28
30
  @cached_property
29
31
  def vellum_client(self) -> Vellum:
@@ -33,10 +35,8 @@ class WorkflowContext:
33
35
  return create_vellum_client()
34
36
 
35
37
  @cached_property
36
- def parent_context(self) -> Optional[ParentContext]:
37
- if self._parent_context:
38
- return self._parent_context
39
- return None
38
+ def execution_context(self) -> ExecutionContext:
39
+ return self._execution_context
40
40
 
41
41
  @cached_property
42
42
  def node_output_mocks_map(self) -> Dict[Type[BaseOutputs], List[MockNodeExecution]]:
@@ -24,6 +24,7 @@ from typing import (
24
24
  get_args,
25
25
  )
26
26
 
27
+ from vellum.workflows.context import get_execution_context
27
28
  from vellum.workflows.edges import Edge
28
29
  from vellum.workflows.emitters.base import BaseWorkflowEmitter
29
30
  from vellum.workflows.errors import WorkflowError, WorkflowErrorCode
@@ -171,6 +172,7 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
171
172
  self.resolvers = resolvers or (self.resolvers if hasattr(self, "resolvers") else [])
172
173
  self._context = context or WorkflowContext()
173
174
  self._store = Store()
175
+ self._execution_context = self._context.execution_context
174
176
 
175
177
  self.validate()
176
178
 
@@ -178,48 +180,64 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
178
180
  def context(self) -> WorkflowContext:
179
181
  return self._context
180
182
 
181
- @classmethod
182
- def get_subgraphs(cls) -> List[Graph]:
183
- original_graph = cls.graph
184
- if isinstance(original_graph, Graph):
185
- return [original_graph]
186
- if isinstance(original_graph, set):
187
- return [
188
- subgraph if isinstance(subgraph, Graph) else Graph.from_node(subgraph) for subgraph in original_graph
189
- ]
190
- if issubclass(original_graph, BaseNode):
191
- return [Graph.from_node(original_graph)]
192
-
193
- raise ValueError(f"Unexpected graph type: {original_graph.__class__}")
194
-
195
- @classmethod
196
- def get_edges(cls) -> Iterator[Edge]:
183
+ @staticmethod
184
+ def _resolve_graph(graph: GraphAttribute) -> List[Graph]:
197
185
  """
198
- Returns an iterator over the edges in the workflow. We use a set to
199
- ensure uniqueness, and the iterator to preserve order.
186
+ Resolves a single graph source to a list of Graph objects.
200
187
  """
188
+ if isinstance(graph, Graph):
189
+ return [graph]
190
+ if isinstance(graph, set):
191
+ graphs = []
192
+ for item in graph:
193
+ if isinstance(item, Graph):
194
+ graphs.append(item)
195
+ elif issubclass(item, BaseNode):
196
+ graphs.append(Graph.from_node(item))
197
+ else:
198
+ raise ValueError(f"Unexpected graph type: {type(item)}")
199
+ return graphs
200
+ if issubclass(graph, BaseNode):
201
+ return [Graph.from_node(graph)]
202
+ raise ValueError(f"Unexpected graph type: {type(graph)}")
201
203
 
204
+ @staticmethod
205
+ def _get_edges_from_subgraphs(subgraphs: Iterable[Graph]) -> Iterator[Edge]:
202
206
  edges = set()
203
- subgraphs = cls.get_subgraphs()
204
207
  for subgraph in subgraphs:
205
208
  for edge in subgraph.edges:
206
209
  if edge not in edges:
207
210
  edges.add(edge)
208
211
  yield edge
209
212
 
213
+ @staticmethod
214
+ def _get_nodes_from_subgraphs(subgraphs: Iterable[Graph]) -> Iterator[Type[BaseNode]]:
215
+ nodes = set()
216
+ for subgraph in subgraphs:
217
+ for node in subgraph.nodes:
218
+ if node not in nodes:
219
+ nodes.add(node)
220
+ yield node
221
+
222
+ @classmethod
223
+ def get_subgraphs(cls) -> List[Graph]:
224
+ return cls._resolve_graph(cls.graph)
225
+
226
+ @classmethod
227
+ def get_edges(cls) -> Iterator[Edge]:
228
+ """
229
+ Returns an iterator over the edges in the workflow. We use a set to
230
+ ensure uniqueness, and the iterator to preserve order.
231
+ """
232
+ return cls._get_edges_from_subgraphs(cls.get_subgraphs())
233
+
210
234
  @classmethod
211
235
  def get_nodes(cls) -> Iterator[Type[BaseNode]]:
212
236
  """
213
237
  Returns an iterator over the nodes in the workflow. We use a set to
214
238
  ensure uniqueness, and the iterator to preserve order.
215
239
  """
216
-
217
- nodes = set()
218
- for subgraph in cls.get_subgraphs():
219
- for node in subgraph.nodes:
220
- if node not in nodes:
221
- nodes.add(node)
222
- yield node
240
+ return cls._get_nodes_from_subgraphs(cls.get_subgraphs())
223
241
 
224
242
  @classmethod
225
243
  def get_unused_subgraphs(cls) -> List[Graph]:
@@ -228,19 +246,9 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
228
246
  """
229
247
  if not hasattr(cls, "unused_graphs"):
230
248
  return []
231
-
232
249
  graphs = []
233
250
  for item in cls.unused_graphs:
234
- if isinstance(item, Graph):
235
- graphs.append(item)
236
- elif isinstance(item, set):
237
- for subitem in item:
238
- if isinstance(subitem, Graph):
239
- graphs.append(subitem)
240
- elif issubclass(subitem, BaseNode):
241
- graphs.append(Graph.from_node(subitem))
242
- elif issubclass(item, BaseNode):
243
- graphs.append(Graph.from_node(item))
251
+ graphs.extend(cls._resolve_graph(item))
244
252
  return graphs
245
253
 
246
254
  @classmethod
@@ -248,29 +256,14 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
248
256
  """
249
257
  Returns an iterator over the nodes that are defined but not used in the graph.
250
258
  """
251
- if not hasattr(cls, "unused_graphs"):
252
- yield from ()
253
- else:
254
- nodes = set()
255
- subgraphs = cls.get_unused_subgraphs()
256
- for subgraph in subgraphs:
257
- for node in subgraph.nodes:
258
- if node not in nodes:
259
- nodes.add(node)
260
- yield node
259
+ return cls._get_nodes_from_subgraphs(cls.get_unused_subgraphs())
261
260
 
262
261
  @classmethod
263
262
  def get_unused_edges(cls) -> Iterator[Edge]:
264
263
  """
265
264
  Returns an iterator over edges that are defined but not used in the graph.
266
265
  """
267
- edges = set()
268
- subgraphs = cls.get_unused_subgraphs()
269
- for subgraph in subgraphs:
270
- for edge in subgraph.edges:
271
- if edge not in edges:
272
- edges.add(edge)
273
- yield edge
266
+ return cls._get_edges_from_subgraphs(cls.get_unused_subgraphs())
274
267
 
275
268
  @classmethod
276
269
  def get_entrypoints(cls) -> Iterable[Type[BaseNode]]:
@@ -329,8 +322,8 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
329
322
  external_inputs=external_inputs,
330
323
  cancel_signal=cancel_signal,
331
324
  node_output_mocks=node_output_mocks,
332
- parent_context=self._context.parent_context,
333
325
  max_concurrency=max_concurrency,
326
+ init_execution_context=self._execution_context,
334
327
  ).stream()
335
328
  first_event: Optional[Union[WorkflowExecutionInitiatedEvent, WorkflowExecutionResumedEvent]] = None
336
329
  last_event = None
@@ -440,8 +433,8 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
440
433
  external_inputs=external_inputs,
441
434
  cancel_signal=cancel_signal,
442
435
  node_output_mocks=node_output_mocks,
443
- parent_context=self.context.parent_context,
444
436
  max_concurrency=max_concurrency,
437
+ init_execution_context=self._execution_context,
445
438
  ).stream():
446
439
  if should_yield(self.__class__, event):
447
440
  yield event
@@ -488,10 +481,19 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
488
481
  return self.get_inputs_class()()
489
482
 
490
483
  def get_default_state(self, workflow_inputs: Optional[InputsType] = None) -> StateType:
484
+ execution_context = get_execution_context()
491
485
  return self.get_state_class()(
492
- meta=StateMeta(
493
- parent=self._parent_state,
494
- workflow_inputs=workflow_inputs or self.get_default_inputs(),
486
+ meta=(
487
+ StateMeta(
488
+ parent=self._parent_state,
489
+ workflow_inputs=workflow_inputs or self.get_default_inputs(),
490
+ trace_id=execution_context.trace_id,
491
+ )
492
+ if execution_context and execution_context.trace_id
493
+ else StateMeta(
494
+ parent=self._parent_state,
495
+ workflow_inputs=workflow_inputs or self.get_default_inputs(),
496
+ )
495
497
  )
496
498
  )
497
499