vellum-ai 1.2.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. vellum/__init__.py +18 -1
  2. vellum/client/__init__.py +3 -0
  3. vellum/client/core/client_wrapper.py +2 -2
  4. vellum/client/errors/__init__.py +10 -1
  5. vellum/client/errors/too_many_requests_error.py +11 -0
  6. vellum/client/errors/unauthorized_error.py +11 -0
  7. vellum/client/reference.md +94 -0
  8. vellum/client/resources/__init__.py +2 -0
  9. vellum/client/resources/events/__init__.py +4 -0
  10. vellum/client/resources/events/client.py +165 -0
  11. vellum/client/resources/events/raw_client.py +207 -0
  12. vellum/client/types/__init__.py +6 -0
  13. vellum/client/types/error_detail_response.py +22 -0
  14. vellum/client/types/event_create_response.py +26 -0
  15. vellum/client/types/execution_thinking_vellum_value.py +1 -1
  16. vellum/client/types/thinking_vellum_value.py +1 -1
  17. vellum/client/types/thinking_vellum_value_request.py +1 -1
  18. vellum/client/types/workflow_event.py +33 -0
  19. vellum/errors/too_many_requests_error.py +3 -0
  20. vellum/errors/unauthorized_error.py +3 -0
  21. vellum/resources/events/__init__.py +3 -0
  22. vellum/resources/events/client.py +3 -0
  23. vellum/resources/events/raw_client.py +3 -0
  24. vellum/types/error_detail_response.py +3 -0
  25. vellum/types/event_create_response.py +3 -0
  26. vellum/types/workflow_event.py +3 -0
  27. vellum/workflows/nodes/displayable/bases/api_node/node.py +4 -0
  28. vellum/workflows/nodes/displayable/bases/api_node/tests/test_node.py +26 -0
  29. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +6 -1
  30. vellum/workflows/nodes/displayable/bases/inline_prompt_node/tests/test_inline_prompt_node.py +22 -0
  31. vellum/workflows/sandbox.py +6 -3
  32. vellum/workflows/state/encoder.py +19 -1
  33. vellum/workflows/utils/hmac.py +44 -0
  34. {vellum_ai-1.2.0.dist-info → vellum_ai-1.2.1.dist-info}/METADATA +1 -1
  35. {vellum_ai-1.2.0.dist-info → vellum_ai-1.2.1.dist-info}/RECORD +46 -28
  36. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +33 -7
  37. vellum_ee/workflows/display/nodes/vellum/tests/test_tool_calling_node.py +239 -1
  38. vellum_ee/workflows/display/tests/test_base_workflow_display.py +53 -1
  39. vellum_ee/workflows/display/utils/expressions.py +4 -0
  40. vellum_ee/workflows/display/utils/registry.py +46 -0
  41. vellum_ee/workflows/display/workflows/base_workflow_display.py +1 -1
  42. vellum_ee/workflows/tests/test_registry.py +169 -0
  43. vellum_ee/workflows/tests/test_server.py +72 -0
  44. {vellum_ai-1.2.0.dist-info → vellum_ai-1.2.1.dist-info}/LICENSE +0 -0
  45. {vellum_ai-1.2.0.dist-info → vellum_ai-1.2.1.dist-info}/WHEEL +0 -0
  46. {vellum_ai-1.2.0.dist-info → vellum_ai-1.2.1.dist-info}/entry_points.txt +0 -0
@@ -1,11 +1,17 @@
1
1
  from uuid import UUID
2
- from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union
2
+ from typing import TYPE_CHECKING, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union
3
3
 
4
4
  from vellum import FunctionDefinition, PromptBlock, RichTextChildBlock, VellumVariable
5
5
  from vellum.workflows.descriptors.base import BaseDescriptor
6
6
  from vellum.workflows.nodes import InlinePromptNode
7
7
  from vellum.workflows.types.core import JsonObject
8
- from vellum.workflows.utils.functions import compile_function_definition
8
+ from vellum.workflows.types.definition import DeploymentDefinition
9
+ from vellum.workflows.types.generics import is_workflow_class
10
+ from vellum.workflows.utils.functions import (
11
+ compile_function_definition,
12
+ compile_inline_workflow_function_definition,
13
+ compile_workflow_deployment_function_definition,
14
+ )
9
15
  from vellum.workflows.utils.uuids import uuid4_from_hash
10
16
  from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
11
17
  from vellum_ee.workflows.display.nodes.utils import raise_if_descriptor
@@ -14,6 +20,9 @@ from vellum_ee.workflows.display.types import WorkflowDisplayContext
14
20
  from vellum_ee.workflows.display.utils.vellum import infer_vellum_variable_type
15
21
  from vellum_ee.workflows.display.vellum import NodeInput
16
22
 
23
+ if TYPE_CHECKING:
24
+ from vellum.workflows.workflows.base import BaseWorkflow
25
+
17
26
 
18
27
  def _contains_descriptors(obj):
19
28
  """Check if an object contains any descriptors or references that need special handling."""
@@ -69,7 +78,10 @@ class BaseInlinePromptNodeDisplay(BaseNodeDisplay[_InlinePromptNodeType], Generi
69
78
  ]
70
79
 
71
80
  functions = (
72
- [self._generate_function_tools(function, i) for i, function in enumerate(function_definitions)]
81
+ [
82
+ self._generate_function_tools(function, i, display_context)
83
+ for i, function in enumerate(function_definitions)
84
+ ]
73
85
  if isinstance(function_definitions, list)
74
86
  else []
75
87
  )
@@ -145,10 +157,24 @@ class BaseInlinePromptNodeDisplay(BaseNodeDisplay[_InlinePromptNodeType], Generi
145
157
 
146
158
  return node_inputs, prompt_inputs
147
159
 
148
- def _generate_function_tools(self, function: Union[FunctionDefinition, Callable], index: int) -> JsonObject:
149
- normalized_functions = (
150
- function if isinstance(function, FunctionDefinition) else compile_function_definition(function)
151
- )
160
+ def _generate_function_tools(
161
+ self,
162
+ function: Union[FunctionDefinition, Callable, DeploymentDefinition, Type["BaseWorkflow"]],
163
+ index: int,
164
+ display_context: WorkflowDisplayContext,
165
+ ) -> JsonObject:
166
+ if isinstance(function, FunctionDefinition):
167
+ normalized_functions = function
168
+ elif is_workflow_class(function):
169
+ normalized_functions = compile_inline_workflow_function_definition(function)
170
+ elif callable(function):
171
+ normalized_functions = compile_function_definition(function)
172
+ elif isinstance(function, DeploymentDefinition):
173
+ normalized_functions = compile_workflow_deployment_function_definition(
174
+ function.model_dump(), display_context.client
175
+ )
176
+ else:
177
+ raise ValueError(f"Unsupported function type: {type(function)}")
152
178
  return {
153
179
  "id": str(uuid4_from_hash(f"{self.node_id}-FUNCTION_DEFINITION-{index}")),
154
180
  "block_type": "FUNCTION_DEFINITION",
@@ -1,13 +1,31 @@
1
+ from datetime import datetime
2
+
1
3
  from vellum.client.types.prompt_parameters import PromptParameters
4
+ from vellum.client.types.release_review_reviewer import ReleaseReviewReviewer
5
+ from vellum.client.types.workflow_deployment_release import (
6
+ ReleaseEnvironment,
7
+ ReleaseReleaseTag,
8
+ SlimReleaseReview,
9
+ WorkflowDeploymentRelease,
10
+ WorkflowDeploymentReleaseWorkflowDeployment,
11
+ WorkflowDeploymentReleaseWorkflowVersion,
12
+ )
2
13
  from vellum.workflows import BaseWorkflow
3
14
  from vellum.workflows.inputs import BaseInputs
15
+ from vellum.workflows.nodes.bases import BaseNode
4
16
  from vellum.workflows.nodes.displayable.code_execution_node.node import CodeExecutionNode
5
17
  from vellum.workflows.nodes.displayable.inline_prompt_node.node import InlinePromptNode
6
18
  from vellum.workflows.nodes.displayable.tool_calling_node.node import ToolCallingNode
7
19
  from vellum.workflows.nodes.displayable.tool_calling_node.state import ToolCallingState
8
20
  from vellum.workflows.nodes.displayable.tool_calling_node.utils import create_router_node, create_tool_prompt_node
21
+ from vellum.workflows.outputs.base import BaseOutputs
9
22
  from vellum.workflows.state.base import BaseState
10
- from vellum.workflows.types.definition import AuthorizationType, EnvironmentVariableReference, MCPServer
23
+ from vellum.workflows.types.definition import (
24
+ AuthorizationType,
25
+ DeploymentDefinition,
26
+ EnvironmentVariableReference,
27
+ MCPServer,
28
+ )
11
29
  from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
12
30
  from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
13
31
 
@@ -185,6 +203,58 @@ def test_serialize_node__tool_calling_node__mcp_server_api_key():
185
203
  }
186
204
 
187
205
 
206
+ def test_serialize_node__tool_calling_node__mcp_server_no_authorization():
207
+ # GIVEN a tool calling node with an mcp server
208
+ class MyToolCallingNode(ToolCallingNode):
209
+ functions = [
210
+ MCPServer(
211
+ name="my-mcp-server",
212
+ url="https://my-mcp-server.com",
213
+ )
214
+ ]
215
+
216
+ # AND a workflow with the tool calling node
217
+ class Workflow(BaseWorkflow):
218
+ graph = MyToolCallingNode
219
+
220
+ # WHEN the workflow is serialized
221
+ workflow_display = get_workflow_display(workflow_class=Workflow)
222
+ serialized_workflow: dict = workflow_display.serialize()
223
+
224
+ # THEN the node should properly serialize the mcp server
225
+ my_tool_calling_node = next(
226
+ node
227
+ for node in serialized_workflow["workflow_raw_data"]["nodes"]
228
+ if node["id"] == str(MyToolCallingNode.__id__)
229
+ )
230
+
231
+ functions_attribute = next(
232
+ attribute for attribute in my_tool_calling_node["attributes"] if attribute["name"] == "functions"
233
+ )
234
+
235
+ assert functions_attribute == {
236
+ "id": "c8957551-cb3d-49af-8053-acd256c1d852",
237
+ "name": "functions",
238
+ "value": {
239
+ "type": "CONSTANT_VALUE",
240
+ "value": {
241
+ "type": "JSON",
242
+ "value": [
243
+ {
244
+ "type": "MCP_SERVER",
245
+ "name": "my-mcp-server",
246
+ "url": "https://my-mcp-server.com",
247
+ "authorization_type": None,
248
+ "bearer_token_value": None,
249
+ "api_key_header_key": None,
250
+ "api_key_header_value": None,
251
+ }
252
+ ],
253
+ },
254
+ },
255
+ }
256
+
257
+
188
258
  def test_serialize_tool_router_node():
189
259
  """
190
260
  Test that the tool router node created by create_router_node serializes successfully.
@@ -406,3 +476,171 @@ def test_serialize_node__tool_calling_node__subworkflow_with_parent_input_refere
406
476
  "combinator": "OR",
407
477
  },
408
478
  }
479
+
480
+
481
+ def test_serialize_tool_prompt_node_with_inline_workflow():
482
+ """
483
+ Test that the tool prompt node created by create_tool_prompt_node serializes successfully with inline workflow.
484
+ """
485
+
486
+ # GIVEN a simple inline workflow for tool calling
487
+ class SimpleWorkflowInputs(BaseInputs):
488
+ message: str
489
+
490
+ class SimpleNode(BaseNode):
491
+ message = SimpleWorkflowInputs.message
492
+
493
+ class Outputs(BaseOutputs):
494
+ result: str
495
+
496
+ def run(self) -> Outputs:
497
+ return self.Outputs(result=f"Processed: {self.message}")
498
+
499
+ class SimpleInlineWorkflow(BaseWorkflow[SimpleWorkflowInputs, BaseState]):
500
+ """A simple workflow for testing inline tool serialization."""
501
+
502
+ graph = SimpleNode
503
+
504
+ class Outputs(BaseOutputs):
505
+ result = SimpleNode.Outputs.result
506
+
507
+ # WHEN we create a tool prompt node using create_tool_prompt_node with inline workflow
508
+ tool_prompt_node = create_tool_prompt_node(
509
+ ml_model="gpt-4o-mini",
510
+ blocks=[],
511
+ functions=[SimpleInlineWorkflow],
512
+ prompt_inputs=None,
513
+ parameters=PromptParameters(),
514
+ )
515
+
516
+ tool_prompt_node_display_class = get_node_display_class(tool_prompt_node)
517
+ tool_prompt_node_display = tool_prompt_node_display_class()
518
+
519
+ # AND we create a workflow that uses this tool prompt node
520
+ class TestWorkflow(BaseWorkflow[BaseInputs, ToolCallingState]):
521
+ graph = tool_prompt_node
522
+
523
+ # WHEN we serialize the entire workflow
524
+ workflow_display = get_workflow_display(workflow_class=TestWorkflow)
525
+ display_context = workflow_display.display_context
526
+ serialized_tool_prompt_node = tool_prompt_node_display.serialize(display_context)
527
+
528
+ # THEN prompt inputs should be serialized correctly
529
+ attributes = serialized_tool_prompt_node["attributes"]
530
+ assert isinstance(attributes, list)
531
+ prompt_inputs_attr = next(
532
+ (attr for attr in attributes if isinstance(attr, dict) and attr["name"] == "prompt_inputs"), None
533
+ )
534
+ assert prompt_inputs_attr == {
535
+ "id": "bc1320a2-23e4-4238-8b00-efbf88e91856",
536
+ "name": "prompt_inputs",
537
+ "value": {
538
+ "type": "DICTIONARY_REFERENCE",
539
+ "entries": [
540
+ {
541
+ "id": "76ceec7b-ec37-474f-ba38-2bfd27cecc5d",
542
+ "key": "chat_history",
543
+ "value": {
544
+ "type": "BINARY_EXPRESSION",
545
+ "lhs": {"type": "CONSTANT_VALUE", "value": {"type": "JSON", "value": []}},
546
+ "operator": "concat",
547
+ "rhs": {
548
+ "type": "WORKFLOW_STATE",
549
+ "state_variable_id": "7a1caaf5-99df-487a-8b2d-6512df2d871a",
550
+ },
551
+ },
552
+ }
553
+ ],
554
+ },
555
+ }
556
+
557
+
558
+ def test_serialize_tool_prompt_node_with_workflow_deployment(vellum_client):
559
+ """
560
+ Test that the tool prompt node serializes successfully with a workflow deployment.
561
+ """
562
+ vellum_client.workflow_deployments.retrieve_workflow_deployment_release.return_value = WorkflowDeploymentRelease(
563
+ id="test-id",
564
+ created=datetime.now(),
565
+ environment=ReleaseEnvironment(
566
+ id="test-id",
567
+ name="test-name",
568
+ label="test-label",
569
+ ),
570
+ created_by=None,
571
+ workflow_version=WorkflowDeploymentReleaseWorkflowVersion(
572
+ id="test-id",
573
+ input_variables=[],
574
+ output_variables=[],
575
+ ),
576
+ deployment=WorkflowDeploymentReleaseWorkflowDeployment(name="test-name"),
577
+ description="test-description",
578
+ release_tags=[
579
+ ReleaseReleaseTag(
580
+ name="test-name",
581
+ source="USER",
582
+ )
583
+ ],
584
+ reviews=[
585
+ SlimReleaseReview(
586
+ id="test-id",
587
+ created=datetime.now(),
588
+ reviewer=ReleaseReviewReviewer(
589
+ id="test-id",
590
+ full_name="test-name",
591
+ ),
592
+ state="APPROVED",
593
+ )
594
+ ],
595
+ )
596
+
597
+ # GIVEN a workflow deployment
598
+ workflow_deployment = DeploymentDefinition(
599
+ deployment="test-deployment",
600
+ release_tag="test-release-tag",
601
+ )
602
+
603
+ # WHEN we create a tool prompt node using create_tool_prompt_node with a workflow deployment
604
+ tool_prompt_node = create_tool_prompt_node(
605
+ ml_model="gpt-4o-mini",
606
+ blocks=[],
607
+ functions=[workflow_deployment],
608
+ prompt_inputs=None,
609
+ parameters=PromptParameters(),
610
+ )
611
+
612
+ tool_prompt_node_display_class = get_node_display_class(tool_prompt_node)
613
+ tool_prompt_node_display = tool_prompt_node_display_class()
614
+
615
+ # AND we create a workflow that uses this tool prompt node
616
+ class TestWorkflow(BaseWorkflow[BaseInputs, ToolCallingState]):
617
+ graph = tool_prompt_node
618
+
619
+ # WHEN we serialize the entire workflow
620
+ workflow_display = get_workflow_display(workflow_class=TestWorkflow)
621
+ display_context = workflow_display.display_context
622
+ serialized_tool_prompt_node = tool_prompt_node_display.serialize(display_context)
623
+
624
+ # THEN functions attribute should be serialized correctly
625
+ attributes = serialized_tool_prompt_node["attributes"]
626
+ assert isinstance(attributes, list)
627
+ functions_attr = next((attr for attr in attributes if isinstance(attr, dict) and attr["name"] == "functions"), None)
628
+ assert functions_attr == {
629
+ "id": "6326ccc4-7cf6-4235-ba3c-a6e860b0c48b",
630
+ "name": "functions",
631
+ "value": {
632
+ "type": "CONSTANT_VALUE",
633
+ "value": {
634
+ "type": "JSON",
635
+ "value": [
636
+ {
637
+ "type": "WORKFLOW_DEPLOYMENT",
638
+ "name": "test-name",
639
+ "description": "test-description",
640
+ "deployment": "test-deployment",
641
+ "release_tag": "test-release-tag",
642
+ }
643
+ ],
644
+ },
645
+ },
646
+ }
@@ -2,7 +2,8 @@ from uuid import UUID
2
2
  from typing import Dict
3
3
 
4
4
  from vellum.workflows.inputs import BaseInputs
5
- from vellum.workflows.nodes import BaseNode
5
+ from vellum.workflows.nodes import BaseNode, InlineSubworkflowNode
6
+ from vellum.workflows.outputs.base import BaseOutputs
6
7
  from vellum.workflows.ports.port import Port
7
8
  from vellum.workflows.references.lazy import LazyReference
8
9
  from vellum.workflows.state import BaseState
@@ -327,3 +328,54 @@ def test_serialize__port_with_lazy_reference():
327
328
  },
328
329
  }
329
330
  ]
331
+
332
+
333
+ def test_global_propagation_deep_nested_subworkflows():
334
+ # GIVEN the root workflow, a middle workflow, and an inner workflow
335
+
336
+ class RootInputs(BaseInputs):
337
+ root_param: str
338
+
339
+ class MiddleInputs(BaseInputs):
340
+ middle_param: str
341
+
342
+ class InnerInputs(BaseInputs):
343
+ inner_param: str
344
+
345
+ class InnerNode(BaseNode):
346
+ class Outputs(BaseOutputs):
347
+ done: bool
348
+
349
+ def run(self) -> Outputs:
350
+ return self.Outputs(done=True)
351
+
352
+ class InnerWorkflow(BaseWorkflow[InnerInputs, BaseState]):
353
+ graph = InnerNode
354
+
355
+ class MiddleInlineSubworkflowNode(InlineSubworkflowNode):
356
+ subworkflow_inputs = {"inner_param": "x"}
357
+ subworkflow = InnerWorkflow
358
+
359
+ class MiddleWorkflow(BaseWorkflow[MiddleInputs, BaseState]):
360
+ graph = MiddleInlineSubworkflowNode
361
+
362
+ class OuterInlineSubworkflowNode(InlineSubworkflowNode):
363
+ subworkflow_inputs = {"middle_param": "y"}
364
+ subworkflow = MiddleWorkflow
365
+
366
+ class RootWorkflow(BaseWorkflow[RootInputs, BaseState]):
367
+ graph = OuterInlineSubworkflowNode
368
+
369
+ # WHEN we build the displays
370
+ root_display = get_workflow_display(workflow_class=RootWorkflow)
371
+ middle_display = get_workflow_display(
372
+ workflow_class=MiddleWorkflow, parent_display_context=root_display.display_context
373
+ )
374
+ inner_display = get_workflow_display(
375
+ workflow_class=InnerWorkflow, parent_display_context=middle_display.display_context
376
+ )
377
+
378
+ # THEN the deepest display must include root + middle + inner inputs in its GLOBAL view
379
+ inner_global_names = {ref.name for ref in inner_display.display_context.global_workflow_input_displays.keys()}
380
+
381
+ assert inner_global_names == {"middle_param", "inner_param", "root_param"}
@@ -11,6 +11,7 @@ from vellum.workflows.expressions.and_ import AndExpression
11
11
  from vellum.workflows.expressions.begins_with import BeginsWithExpression
12
12
  from vellum.workflows.expressions.between import BetweenExpression
13
13
  from vellum.workflows.expressions.coalesce_expression import CoalesceExpression
14
+ from vellum.workflows.expressions.concat import ConcatExpression
14
15
  from vellum.workflows.expressions.contains import ContainsExpression
15
16
  from vellum.workflows.expressions.does_not_begin_with import DoesNotBeginWithExpression
16
17
  from vellum.workflows.expressions.does_not_contain import DoesNotContainExpression
@@ -105,6 +106,8 @@ def convert_descriptor_to_operator(descriptor: BaseDescriptor) -> LogicalOperato
105
106
  return "+"
106
107
  elif isinstance(descriptor, MinusExpression):
107
108
  return "-"
109
+ elif isinstance(descriptor, ConcatExpression):
110
+ return "concat"
108
111
  else:
109
112
  raise ValueError(f"Unsupported descriptor type: {descriptor}")
110
113
 
@@ -171,6 +174,7 @@ def _serialize_condition(display_context: "WorkflowDisplayContext", condition: B
171
174
  AndExpression,
172
175
  BeginsWithExpression,
173
176
  CoalesceExpression,
177
+ ConcatExpression,
174
178
  ContainsExpression,
175
179
  DoesNotBeginWithExpression,
176
180
  DoesNotContainExpression,
@@ -1,10 +1,14 @@
1
+ from uuid import UUID
1
2
  from typing import TYPE_CHECKING, Dict, Optional, Type
2
3
 
4
+ from vellum.workflows.events.types import BaseEvent
3
5
  from vellum.workflows.nodes import BaseNode
4
6
  from vellum.workflows.workflows.base import BaseWorkflow
5
7
 
6
8
  if TYPE_CHECKING:
9
+ from vellum.workflows.events.types import ParentContext
7
10
  from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
11
+ from vellum_ee.workflows.display.types import WorkflowDisplayContext
8
12
  from vellum_ee.workflows.display.workflows.base_workflow_display import BaseWorkflowDisplay
9
13
 
10
14
 
@@ -14,6 +18,9 @@ _workflow_display_registry: Dict[Type[BaseWorkflow], Type["BaseWorkflowDisplay"]
14
18
  # Used to store the mapping between node types and their display classes
15
19
  _node_display_registry: Dict[Type[BaseNode], Type["BaseNodeDisplay"]] = {}
16
20
 
21
+ # Registry to store active workflow display contexts by span ID for nested workflow inheritance
22
+ _active_workflow_display_contexts: Dict[UUID, "WorkflowDisplayContext"] = {}
23
+
17
24
 
18
25
  def get_from_workflow_display_registry(workflow_class: Type[BaseWorkflow]) -> Optional[Type["BaseWorkflowDisplay"]]:
19
26
  return _workflow_display_registry.get(workflow_class)
@@ -35,3 +42,42 @@ def get_from_node_display_registry(node_class: Type[BaseNode]) -> Optional[Type[
35
42
 
36
43
  def register_node_display_class(node_class: Type[BaseNode], node_display_class: Type["BaseNodeDisplay"]) -> None:
37
44
  _node_display_registry[node_class] = node_display_class
45
+
46
+
47
+ def register_workflow_display_context(span_id: UUID, display_context: "WorkflowDisplayContext") -> None:
48
+ """Register a workflow display context by span ID for nested workflow inheritance."""
49
+ _active_workflow_display_contexts[span_id] = display_context
50
+
51
+
52
+ def _get_parent_display_context_for_span(span_id: UUID) -> Optional["WorkflowDisplayContext"]:
53
+ """Get the parent display context for a given span ID."""
54
+ return _active_workflow_display_contexts.get(span_id)
55
+
56
+
57
+ def get_parent_display_context_from_event(event: BaseEvent) -> Optional["WorkflowDisplayContext"]:
58
+ """Extract parent display context from an event by traversing the parent chain.
59
+
60
+ This function traverses up the parent chain starting from the event's parent,
61
+ looking for workflow parents and attempting to get their display context.
62
+
63
+ Args:
64
+ event: The event to extract parent display context from
65
+
66
+ Returns:
67
+ The parent workflow display context if found, None otherwise
68
+ """
69
+ if not event.parent:
70
+ return None
71
+
72
+ current_parent: Optional["ParentContext"] = event.parent
73
+ while current_parent:
74
+ if current_parent.type == "WORKFLOW":
75
+ # Found a parent workflow, try to get its display context
76
+ parent_span_id = current_parent.span_id
77
+ parent_display_context = _get_parent_display_context_for_span(parent_span_id)
78
+ if parent_display_context:
79
+ return parent_display_context
80
+ # Move up the parent chain
81
+ current_parent = current_parent.parent
82
+
83
+ return None
@@ -528,7 +528,7 @@ class BaseWorkflowDisplay(Generic[WorkflowType]):
528
528
  workflow_input_displays: WorkflowInputsDisplays = {}
529
529
  # If we're dealing with a nested workflow, then it should have access to the inputs of its parents.
530
530
  global_workflow_input_displays = (
531
- copy(self._parent_display_context.workflow_input_displays) if self._parent_display_context else {}
531
+ copy(self._parent_display_context.global_workflow_input_displays) if self._parent_display_context else {}
532
532
  )
533
533
  for workflow_input in self._workflow.get_inputs_class():
534
534
  workflow_input_display_overrides = self.inputs_display.get(workflow_input)
@@ -0,0 +1,169 @@
1
+ from datetime import datetime, timezone
2
+ from uuid import uuid4
3
+
4
+ from vellum.workflows.events.types import NodeParentContext, WorkflowParentContext
5
+ from vellum.workflows.events.workflow import WorkflowExecutionInitiatedBody, WorkflowExecutionInitiatedEvent
6
+ from vellum.workflows.inputs.base import BaseInputs
7
+ from vellum.workflows.nodes import BaseNode
8
+ from vellum.workflows.state.base import BaseState
9
+ from vellum.workflows.workflows.base import BaseWorkflow
10
+ from vellum_ee.workflows.display.utils.registry import (
11
+ get_parent_display_context_from_event,
12
+ register_workflow_display_context,
13
+ )
14
+
15
+
16
+ class MockInputs(BaseInputs):
17
+ pass
18
+
19
+
20
+ class MockState(BaseState):
21
+ pass
22
+
23
+
24
+ class MockNode(BaseNode):
25
+ pass
26
+
27
+
28
+ class MockWorkflow(BaseWorkflow[MockInputs, MockState]):
29
+ pass
30
+
31
+
32
+ class MockWorkflowDisplayContext:
33
+ pass
34
+
35
+
36
+ def test_get_parent_display_context_from_event__no_parent():
37
+ """Test event with no parent returns None"""
38
+ # GIVEN a workflow execution initiated event with no parent
39
+ event: WorkflowExecutionInitiatedEvent = WorkflowExecutionInitiatedEvent(
40
+ id=uuid4(),
41
+ timestamp=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
42
+ trace_id=uuid4(),
43
+ span_id=uuid4(),
44
+ body=WorkflowExecutionInitiatedBody(
45
+ workflow_definition=MockWorkflow,
46
+ inputs=MockInputs(),
47
+ ),
48
+ parent=None, # No parent
49
+ )
50
+
51
+ # WHEN getting parent display context
52
+ result = get_parent_display_context_from_event(event)
53
+
54
+ # THEN it should return None
55
+ assert result is None
56
+
57
+
58
+ def test_get_parent_display_context_from_event__non_workflow_parent():
59
+ """Test event with non-workflow parent continues traversal"""
60
+ # GIVEN an event with a non-workflow parent (NodeParentContext)
61
+ non_workflow_parent = NodeParentContext(node_definition=MockNode, span_id=uuid4(), parent=None)
62
+
63
+ event: WorkflowExecutionInitiatedEvent = WorkflowExecutionInitiatedEvent(
64
+ id=uuid4(),
65
+ timestamp=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
66
+ trace_id=uuid4(),
67
+ span_id=uuid4(),
68
+ body=WorkflowExecutionInitiatedBody(
69
+ workflow_definition=MockWorkflow,
70
+ inputs=MockInputs(),
71
+ ),
72
+ parent=non_workflow_parent,
73
+ )
74
+
75
+ # WHEN getting parent display context
76
+ result = get_parent_display_context_from_event(event)
77
+
78
+ # THEN it should return None (no workflow parent found)
79
+ assert result is None
80
+
81
+
82
+ def test_get_parent_display_context_from_event__nested_workflow_parents():
83
+ """Test event with nested workflow parents traverses correctly"""
84
+ # GIVEN a chain of nested contexts:
85
+ # Event -> WorkflowParent -> NodeParent -> MiddleWorkflowParent -> NodeParent
86
+
87
+ # Top level workflow parent
88
+ top_workflow_span_id = uuid4()
89
+ top_context = MockWorkflowDisplayContext()
90
+ register_workflow_display_context(top_workflow_span_id, top_context) # type: ignore[arg-type]
91
+
92
+ top_workflow_parent = WorkflowParentContext(
93
+ workflow_definition=MockWorkflow, span_id=top_workflow_span_id, parent=None
94
+ )
95
+
96
+ top_node_parent = NodeParentContext(node_definition=MockNode, span_id=uuid4(), parent=top_workflow_parent)
97
+
98
+ # AND middle workflow parent (no display context)
99
+ middle_workflow_span_id = uuid4()
100
+ middle_workflow_parent = WorkflowParentContext(
101
+ workflow_definition=MockWorkflow, span_id=middle_workflow_span_id, parent=top_node_parent
102
+ )
103
+
104
+ # AND node parent between middle workflow and event
105
+ node_parent = NodeParentContext(node_definition=MockNode, span_id=uuid4(), parent=middle_workflow_parent)
106
+
107
+ event: WorkflowExecutionInitiatedEvent = WorkflowExecutionInitiatedEvent(
108
+ id=uuid4(),
109
+ timestamp=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
110
+ trace_id=uuid4(),
111
+ span_id=uuid4(),
112
+ body=WorkflowExecutionInitiatedBody(
113
+ workflow_definition=MockWorkflow,
114
+ inputs=MockInputs(),
115
+ ),
116
+ parent=node_parent,
117
+ )
118
+
119
+ # WHEN getting parent display context
120
+ result = get_parent_display_context_from_event(event)
121
+
122
+ # THEN it should find the top-level workflow context
123
+ assert result == top_context
124
+
125
+
126
+ def test_get_parent_display_context_from_event__middle_workflow_has_context():
127
+ """Test event returns middle workflow context when it's the first one with registered context"""
128
+ # GIVEN a chain of nested contexts:
129
+ # Event -> WorkflowParent -> NodeParent -> MiddleWorkflowParent -> NodeParent
130
+
131
+ top_workflow_span_id = uuid4()
132
+ top_context = MockWorkflowDisplayContext()
133
+ register_workflow_display_context(top_workflow_span_id, top_context) # type: ignore[arg-type]
134
+
135
+ top_workflow_parent = WorkflowParentContext(
136
+ workflow_definition=MockWorkflow, span_id=top_workflow_span_id, parent=None
137
+ )
138
+
139
+ # AND node parent between top workflow and middle workflow
140
+ top_node_parent = NodeParentContext(node_definition=MockNode, span_id=uuid4(), parent=top_workflow_parent)
141
+
142
+ # AND middle workflow parent
143
+ middle_workflow_span_id = uuid4()
144
+ middle_context = MockWorkflowDisplayContext()
145
+ register_workflow_display_context(middle_workflow_span_id, middle_context) # type: ignore[arg-type]
146
+
147
+ middle_workflow_parent = WorkflowParentContext(
148
+ workflow_definition=MockWorkflow, span_id=middle_workflow_span_id, parent=top_node_parent
149
+ )
150
+
151
+ node_parent = NodeParentContext(node_definition=MockNode, span_id=uuid4(), parent=middle_workflow_parent)
152
+
153
+ event: WorkflowExecutionInitiatedEvent = WorkflowExecutionInitiatedEvent(
154
+ id=uuid4(),
155
+ timestamp=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
156
+ trace_id=uuid4(),
157
+ span_id=uuid4(),
158
+ body=WorkflowExecutionInitiatedBody(
159
+ workflow_definition=MockWorkflow,
160
+ inputs=MockInputs(),
161
+ ),
162
+ parent=node_parent,
163
+ )
164
+
165
+ # WHEN getting parent display context
166
+ result = get_parent_display_context_from_event(event)
167
+
168
+ # THEN it should find the MIDDLE workflow context
169
+ assert result == middle_context