vellum-ai 0.10.8__py3-none-any.whl → 0.11.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (103) hide show
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/client/types/logical_operator.py +2 -0
  3. vellum/evaluations/resources.py +7 -12
  4. vellum/evaluations/utils/env.py +1 -3
  5. vellum/evaluations/utils/paginator.py +0 -1
  6. vellum/evaluations/utils/typing.py +1 -1
  7. vellum/evaluations/utils/uuid.py +1 -1
  8. vellum/plugins/vellum_mypy.py +3 -1
  9. vellum/workflows/descriptors/utils.py +27 -0
  10. vellum/workflows/events/__init__.py +0 -2
  11. vellum/workflows/events/node.py +7 -6
  12. vellum/workflows/events/tests/test_event.py +2 -2
  13. vellum/workflows/events/types.py +35 -30
  14. vellum/workflows/events/workflow.py +33 -8
  15. vellum/workflows/nodes/bases/base.py +49 -26
  16. vellum/workflows/nodes/bases/tests/test_base_node.py +0 -1
  17. vellum/workflows/nodes/core/templating_node/node.py +1 -0
  18. vellum/workflows/nodes/core/try_node/node.py +22 -4
  19. vellum/workflows/nodes/core/try_node/tests/test_node.py +16 -3
  20. vellum/workflows/nodes/displayable/bases/api_node/node.py +1 -1
  21. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +0 -1
  22. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +0 -1
  23. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +2 -1
  24. vellum/workflows/nodes/displayable/bases/search_node.py +0 -1
  25. vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +0 -1
  26. vellum/workflows/nodes/displayable/code_execution_node/utils.py +3 -2
  27. vellum/workflows/nodes/displayable/conditional_node/node.py +1 -1
  28. vellum/workflows/nodes/displayable/guardrail_node/node.py +0 -1
  29. vellum/workflows/nodes/displayable/inline_prompt_node/node.py +1 -0
  30. vellum/workflows/nodes/displayable/prompt_deployment_node/node.py +3 -1
  31. vellum/workflows/nodes/displayable/search_node/node.py +1 -0
  32. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +3 -2
  33. vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py +10 -7
  34. vellum/workflows/nodes/displayable/tests/test_search_node_wth_text_output.py +0 -1
  35. vellum/workflows/outputs/base.py +2 -4
  36. vellum/workflows/ports/node_ports.py +1 -1
  37. vellum/workflows/runner/runner.py +185 -157
  38. vellum/workflows/state/base.py +55 -23
  39. vellum/workflows/state/context.py +26 -3
  40. vellum/workflows/types/core.py +1 -0
  41. vellum/workflows/types/tests/test_utils.py +1 -0
  42. vellum/workflows/types/utils.py +0 -1
  43. vellum/workflows/utils/functions.py +74 -0
  44. vellum/workflows/utils/tests/test_functions.py +171 -0
  45. vellum/workflows/utils/tests/test_vellum_variables.py +0 -1
  46. vellum/workflows/utils/vellum_variables.py +2 -2
  47. vellum/workflows/workflows/base.py +84 -10
  48. vellum/workflows/workflows/event_filters.py +53 -0
  49. {vellum_ai-0.10.8.dist-info → vellum_ai-0.11.0.dist-info}/METADATA +1 -1
  50. {vellum_ai-0.10.8.dist-info → vellum_ai-0.11.0.dist-info}/RECORD +101 -93
  51. vellum_cli/__init__.py +147 -13
  52. vellum_cli/config.py +0 -1
  53. vellum_cli/image_push.py +1 -1
  54. vellum_cli/pull.py +29 -19
  55. vellum_cli/push.py +9 -10
  56. vellum_cli/tests/__init__.py +0 -0
  57. vellum_cli/tests/conftest.py +40 -0
  58. vellum_cli/tests/test_main.py +11 -0
  59. vellum_cli/tests/test_pull.py +125 -71
  60. vellum_cli/tests/test_push.py +173 -0
  61. vellum_ee/workflows/display/nodes/base_node_display.py +3 -2
  62. vellum_ee/workflows/display/nodes/base_node_vellum_display.py +2 -2
  63. vellum_ee/workflows/display/nodes/get_node_display_class.py +1 -1
  64. vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +1 -1
  65. vellum_ee/workflows/display/nodes/vellum/__init__.py +5 -3
  66. vellum_ee/workflows/display/nodes/vellum/api_node.py +4 -7
  67. vellum_ee/workflows/display/nodes/vellum/conditional_node.py +39 -22
  68. vellum_ee/workflows/display/nodes/vellum/error_node.py +49 -0
  69. vellum_ee/workflows/display/nodes/vellum/final_output_node.py +0 -2
  70. vellum_ee/workflows/display/nodes/vellum/guardrail_node.py +1 -1
  71. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +1 -1
  72. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +4 -2
  73. vellum_ee/workflows/display/nodes/vellum/map_node.py +11 -5
  74. vellum_ee/workflows/display/nodes/vellum/merge_node.py +2 -2
  75. vellum_ee/workflows/display/nodes/vellum/note_node.py +1 -3
  76. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +1 -1
  77. vellum_ee/workflows/display/nodes/vellum/search_node.py +1 -1
  78. vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +1 -1
  79. vellum_ee/workflows/display/nodes/vellum/templating_node.py +1 -1
  80. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +5 -5
  81. vellum_ee/workflows/display/nodes/vellum/utils.py +4 -4
  82. vellum_ee/workflows/display/tests/test_vellum_workflow_display.py +45 -0
  83. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +13 -24
  84. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +13 -39
  85. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +203 -0
  86. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +2 -2
  87. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +62 -58
  88. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +25 -4
  89. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +2 -1
  90. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +2 -2
  91. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +2 -2
  92. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +1 -1
  93. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +2 -1
  94. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +2 -2
  95. vellum_ee/workflows/display/types.py +4 -4
  96. vellum_ee/workflows/display/utils/vellum.py +2 -6
  97. vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +4 -1
  98. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +6 -2
  99. vellum/workflows/events/utils.py +0 -5
  100. vellum/workflows/runner/types.py +0 -16
  101. {vellum_ai-0.10.8.dist-info → vellum_ai-0.11.0.dist-info}/LICENSE +0 -0
  102. {vellum_ai-0.10.8.dist-info → vellum_ai-0.11.0.dist-info}/WHEEL +0 -0
  103. {vellum_ai-0.10.8.dist-info → vellum_ai-0.11.0.dist-info}/entry_points.txt +0 -0
@@ -17,7 +17,7 @@ class BaseClientWrapper:
17
17
  headers: typing.Dict[str, str] = {
18
18
  "X-Fern-Language": "Python",
19
19
  "X-Fern-SDK-Name": "vellum-ai",
20
- "X-Fern-SDK-Version": "0.10.8",
20
+ "X-Fern-SDK-Version": "0.11.0",
21
21
  }
22
22
  headers["X_API_KEY"] = self.api_key
23
23
  return headers
@@ -24,6 +24,8 @@ LogicalOperator = typing.Union[
24
24
  "notBetween",
25
25
  "blank",
26
26
  "notBlank",
27
+ "coalesce",
28
+ "accessField",
27
29
  ],
28
30
  typing.Any,
29
31
  ]
@@ -1,17 +1,14 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from functools import cached_property
3
4
  import logging
4
5
  import time
5
- from functools import cached_property
6
- from typing import Callable, Generator, List, cast, Iterable
7
6
  from uuid import UUID
7
+ from typing import Callable, Generator, Iterable, List, cast
8
8
 
9
- from vellum import TestSuiteRunRead, TestSuiteRunMetricNumberOutput
10
- from vellum.client import Vellum, OMIT
11
- from vellum.evaluations.constants import (
12
- DEFAULT_MAX_POLLING_DURATION_MS,
13
- DEFAULT_POLLING_INTERVAL_MS,
14
- )
9
+ from vellum import TestSuiteRunMetricNumberOutput, TestSuiteRunRead
10
+ from vellum.client import OMIT, Vellum
11
+ from vellum.evaluations.constants import DEFAULT_MAX_POLLING_DURATION_MS, DEFAULT_POLLING_INTERVAL_MS
15
12
  from vellum.evaluations.exceptions import TestSuiteRunResultsException
16
13
  from vellum.evaluations.utils.env import get_api_key
17
14
  from vellum.evaluations.utils.paginator import PaginatedResults, get_all_results
@@ -21,9 +18,9 @@ from vellum.types import (
21
18
  ExternalTestCaseExecutionRequest,
22
19
  NamedTestCaseVariableValueRequest,
23
20
  TestCaseVariableValue,
24
- TestSuiteRunExternalExecConfigRequest,
25
21
  TestSuiteRunExecution,
26
22
  TestSuiteRunExternalExecConfigDataRequest,
23
+ TestSuiteRunExternalExecConfigRequest,
27
24
  TestSuiteRunMetricOutput,
28
25
  TestSuiteRunState,
29
26
  )
@@ -161,9 +158,7 @@ class VellumTestSuiteRunResults:
161
158
 
162
159
  metric_outputs: list[TestSuiteRunMetricNumberOutput] = []
163
160
 
164
- for output in self.get_metric_outputs(
165
- metric_identifier=metric_identifier, output_identifier=output_identifier
166
- ):
161
+ for output in self.get_metric_outputs(metric_identifier=metric_identifier, output_identifier=output_identifier):
167
162
  if output.type != "NUMBER":
168
163
  raise TestSuiteRunResultsException(
169
164
  f"Expected a numeric metric output, but got a {output.type} output instead."
@@ -6,8 +6,6 @@ from .exceptions import VellumClientException
6
6
  def get_api_key() -> str:
7
7
  api_key = os.environ.get("VELLUM_API_KEY")
8
8
  if api_key is None:
9
- raise VellumClientException(
10
- "`VELLUM_API_KEY` environment variable is required to be set."
11
- )
9
+ raise VellumClientException("`VELLUM_API_KEY` environment variable is required to be set.")
12
10
 
13
11
  return api_key
@@ -1,7 +1,6 @@
1
1
  from dataclasses import dataclass
2
2
  from typing import Callable, Generator, Generic, List, TypeVar, Union
3
3
 
4
-
5
4
  Result = TypeVar("Result")
6
5
 
7
6
 
@@ -1,4 +1,4 @@
1
- from typing import TypeVar, Optional
1
+ from typing import Optional, TypeVar
2
2
 
3
3
  _T = TypeVar("_T")
4
4
 
@@ -1,5 +1,5 @@
1
- from typing import Union
2
1
  import uuid
2
+ from typing import Union
3
3
 
4
4
 
5
5
  def is_valid_uuid(val: Union[str, uuid.UUID, None]) -> bool:
@@ -154,7 +154,9 @@ class VellumMypyPlugin(Plugin):
154
154
  def _base_class_hook(self, ctx: ClassDefContext) -> None:
155
155
  if _is_subclass(ctx.cls.info, "vellum.workflows.nodes.core.templating_node.node.TemplatingNode"):
156
156
  self._dynamic_output_node_class_hook(ctx, "result")
157
- elif _is_subclass(ctx.cls.info, "vellum.workflows.nodes.displayable.code_execution_node.node.CodeExecutionNode"):
157
+ elif _is_subclass(
158
+ ctx.cls.info, "vellum.workflows.nodes.displayable.code_execution_node.node.CodeExecutionNode"
159
+ ):
158
160
  self._dynamic_output_node_class_hook(ctx, "result")
159
161
  elif _is_subclass(ctx.cls.info, "vellum.workflows.nodes.displayable.final_output_node.node.FinalOutputNode"):
160
162
  self._dynamic_output_node_class_hook(ctx, "value")
@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Sequence, Set, TypeVar, Union, cast, ove
5
5
 
6
6
  from pydantic import BaseModel
7
7
 
8
+ from vellum.workflows.constants import UNDEF
8
9
  from vellum.workflows.descriptors.base import BaseDescriptor
9
10
  from vellum.workflows.state.base import BaseState
10
11
 
@@ -88,3 +89,29 @@ def resolve_value(
88
89
  return cast(_T, set_value)
89
90
 
90
91
  return value
92
+
93
+
94
+ def is_unresolved(value: Any) -> bool:
95
+ """
96
+ Recursively checks if a value has an unresolved value, represented by UNDEF.
97
+ """
98
+
99
+ if value is UNDEF:
100
+ return True
101
+
102
+ if dataclasses.is_dataclass(value):
103
+ return any(is_unresolved(getattr(value, field.name)) for field in dataclasses.fields(value))
104
+
105
+ if isinstance(value, BaseModel):
106
+ return any(is_unresolved(getattr(value, key)) for key in value.model_fields.keys())
107
+
108
+ if isinstance(value, Mapping):
109
+ return any(is_unresolved(item) for item in value.values())
110
+
111
+ if isinstance(value, Sequence):
112
+ return any(is_unresolved(item) for item in value)
113
+
114
+ if isinstance(value, Set):
115
+ return any(is_unresolved(item) for item in value)
116
+
117
+ return False
@@ -5,7 +5,6 @@ from .node import (
5
5
  NodeExecutionRejectedEvent,
6
6
  NodeExecutionStreamingEvent,
7
7
  )
8
- from .types import WorkflowEventType
9
8
  from .workflow import (
10
9
  WorkflowEvent,
11
10
  WorkflowEventStream,
@@ -27,5 +26,4 @@ __all__ = [
27
26
  "WorkflowExecutionStreamingEvent",
28
27
  "WorkflowEvent",
29
28
  "WorkflowEventStream",
30
- "WorkflowEventType",
31
29
  ]
@@ -1,12 +1,10 @@
1
- from typing import Any, Dict, Generic, Iterable, List, Literal, Optional, Set, Type, Union
1
+ from typing import TYPE_CHECKING, Any, Dict, Generic, List, Literal, Optional, Set, Type, Union
2
2
 
3
- from pydantic import ConfigDict, SerializerFunctionWrapHandler, field_serializer, model_serializer
4
- from pydantic.main import IncEx
3
+ from pydantic import SerializerFunctionWrapHandler, field_serializer, model_serializer
5
4
 
6
5
  from vellum.core.pydantic_utilities import UniversalBaseModel
7
6
  from vellum.workflows.errors import VellumError
8
7
  from vellum.workflows.expressions.accessor import AccessorExpression
9
- from vellum.workflows.nodes.bases import BaseNode
10
8
  from vellum.workflows.outputs.base import BaseOutput
11
9
  from vellum.workflows.ports.port import Port
12
10
  from vellum.workflows.references.node import NodeReference
@@ -14,9 +12,12 @@ from vellum.workflows.types.generics import OutputsType
14
12
 
15
13
  from .types import BaseEvent, default_serializer, serialize_type_encoder
16
14
 
15
+ if TYPE_CHECKING:
16
+ from vellum.workflows.nodes.bases import BaseNode
17
+
17
18
 
18
19
  class _BaseNodeExecutionBody(UniversalBaseModel):
19
- node_definition: Type[BaseNode]
20
+ node_definition: Type["BaseNode"]
20
21
 
21
22
  @field_serializer("node_definition")
22
23
  def serialize_node_definition(self, node_definition: Type, _info: Any) -> Dict[str, Any]:
@@ -36,7 +37,7 @@ class _BaseNodeEvent(BaseEvent):
36
37
  body: _BaseNodeExecutionBody
37
38
 
38
39
  @property
39
- def node_definition(self) -> Type[BaseNode]:
40
+ def node_definition(self) -> Type["BaseNode"]:
40
41
  return self.body.node_definition
41
42
 
42
43
 
@@ -1,6 +1,5 @@
1
1
  import pytest
2
2
  from datetime import datetime
3
- import json
4
3
  from uuid import UUID
5
4
 
6
5
  from deepdiff import DeepDiff
@@ -100,7 +99,8 @@ module_root = name_parts[: name_parts.index("events")]
100
99
  node_definition=MockNode,
101
100
  span_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
102
101
  parent=WorkflowParentContext(
103
- workflow_definition=MockWorkflow, span_id=UUID("123e4567-e89b-12d3-a456-426614174000")
102
+ workflow_definition=MockWorkflow,
103
+ span_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
104
104
  ),
105
105
  ),
106
106
  ),
@@ -1,24 +1,14 @@
1
1
  from datetime import datetime
2
- from enum import Enum
3
2
  import json
4
3
  from uuid import UUID, uuid4
5
- from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Type, Union
4
+ from typing import Annotated, Any, Dict, List, Literal, Optional, Union
6
5
 
7
- from pydantic import Field, field_serializer
6
+ from pydantic import BeforeValidator, Field
8
7
 
9
8
  from vellum.core.pydantic_utilities import UniversalBaseModel
10
9
  from vellum.workflows.state.encoder import DefaultStateEncoder
11
10
  from vellum.workflows.types.utils import datetime_now
12
11
 
13
- if TYPE_CHECKING:
14
- from vellum.workflows.nodes.bases.base import BaseNode
15
- from vellum.workflows.workflows.base import BaseWorkflow
16
-
17
-
18
- class WorkflowEventType(Enum):
19
- NODE = "NODE"
20
- WORKFLOW = "WORKFLOW"
21
-
22
12
 
23
13
  def default_datetime_factory() -> datetime:
24
14
  """
@@ -47,9 +37,25 @@ def default_serializer(obj: Any) -> Any:
47
37
  )
48
38
 
49
39
 
40
+ class CodeResourceDefinition(UniversalBaseModel):
41
+ name: str
42
+ module: List[str]
43
+
44
+ @staticmethod
45
+ def encode(obj: type) -> "CodeResourceDefinition":
46
+ return CodeResourceDefinition(**serialize_type_encoder(obj))
47
+
48
+
49
+ VellumCodeResourceDefinition = Annotated[
50
+ CodeResourceDefinition,
51
+ BeforeValidator(lambda d: (d if type(d) is dict else serialize_type_encoder(d))),
52
+ ]
53
+
54
+
50
55
  class BaseParentContext(UniversalBaseModel):
51
56
  span_id: UUID
52
- parent: Optional['ParentContext'] = None
57
+ parent: Optional["ParentContext"] = None
58
+ type: str
53
59
 
54
60
 
55
61
  class BaseDeploymentParentContext(BaseParentContext):
@@ -73,29 +79,28 @@ class PromptDeploymentParentContext(BaseDeploymentParentContext):
73
79
 
74
80
  class NodeParentContext(BaseParentContext):
75
81
  type: Literal["WORKFLOW_NODE"] = "WORKFLOW_NODE"
76
- node_definition: Type['BaseNode']
77
-
78
- @field_serializer("node_definition")
79
- def serialize_node_definition(self, definition: Type, _info: Any) -> Dict[str, Any]:
80
- return serialize_type_encoder(definition)
82
+ node_definition: VellumCodeResourceDefinition
81
83
 
82
84
 
83
85
  class WorkflowParentContext(BaseParentContext):
84
86
  type: Literal["WORKFLOW"] = "WORKFLOW"
85
- workflow_definition: Type['BaseWorkflow']
86
-
87
- @field_serializer("workflow_definition")
88
- def serialize_workflow_definition(self, definition: Type, _info: Any) -> Dict[str, Any]:
89
- return serialize_type_encoder(definition)
90
-
91
-
92
- ParentContext = Union[
93
- NodeParentContext,
94
- WorkflowParentContext,
95
- PromptDeploymentParentContext,
96
- WorkflowDeploymentParentContext,
87
+ workflow_definition: VellumCodeResourceDefinition
88
+
89
+
90
+ # Define the discriminated union
91
+ ParentContext = Annotated[
92
+ Union[
93
+ WorkflowParentContext,
94
+ NodeParentContext,
95
+ WorkflowDeploymentParentContext,
96
+ PromptDeploymentParentContext,
97
+ ],
98
+ Field(discriminator="type"),
97
99
  ]
98
100
 
101
+ # Update the forward references
102
+ BaseParentContext.model_rebuild()
103
+
99
104
 
100
105
  class BaseEvent(UniversalBaseModel):
101
106
  id: UUID = Field(default_factory=uuid4)
@@ -3,11 +3,10 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Generic, Iterable, Liter
3
3
  from pydantic import field_serializer
4
4
 
5
5
  from vellum.core.pydantic_utilities import UniversalBaseModel
6
-
7
6
  from vellum.workflows.errors import VellumError
8
7
  from vellum.workflows.outputs.base import BaseOutput
9
8
  from vellum.workflows.references import ExternalInputReference
10
- from vellum.workflows.types.generics import OutputsType, WorkflowInputsType
9
+ from vellum.workflows.types.generics import OutputsType, StateType, WorkflowInputsType
11
10
 
12
11
  from .node import (
13
12
  NodeExecutionFulfilledEvent,
@@ -31,6 +30,14 @@ class _BaseWorkflowExecutionBody(UniversalBaseModel):
31
30
  return serialize_type_encoder(workflow_definition)
32
31
 
33
32
 
33
+ class _BaseWorkflowEvent(BaseEvent):
34
+ body: _BaseWorkflowExecutionBody
35
+
36
+ @property
37
+ def workflow_definition(self) -> Type["BaseWorkflow"]:
38
+ return self.body.workflow_definition
39
+
40
+
34
41
  class WorkflowExecutionInitiatedBody(_BaseWorkflowExecutionBody, Generic[WorkflowInputsType]):
35
42
  inputs: WorkflowInputsType
36
43
 
@@ -39,7 +46,7 @@ class WorkflowExecutionInitiatedBody(_BaseWorkflowExecutionBody, Generic[Workflo
39
46
  return default_serializer(inputs)
40
47
 
41
48
 
42
- class WorkflowExecutionInitiatedEvent(BaseEvent, Generic[WorkflowInputsType]):
49
+ class WorkflowExecutionInitiatedEvent(_BaseWorkflowEvent, Generic[WorkflowInputsType]):
43
50
  name: Literal["workflow.execution.initiated"] = "workflow.execution.initiated"
44
51
  body: WorkflowExecutionInitiatedBody[WorkflowInputsType]
45
52
 
@@ -56,7 +63,7 @@ class WorkflowExecutionStreamingBody(_BaseWorkflowExecutionBody):
56
63
  return default_serializer(output)
57
64
 
58
65
 
59
- class WorkflowExecutionStreamingEvent(BaseEvent):
66
+ class WorkflowExecutionStreamingEvent(_BaseWorkflowEvent):
60
67
  name: Literal["workflow.execution.streaming"] = "workflow.execution.streaming"
61
68
  body: WorkflowExecutionStreamingBody
62
69
 
@@ -73,7 +80,7 @@ class WorkflowExecutionFulfilledBody(_BaseWorkflowExecutionBody, Generic[Outputs
73
80
  return default_serializer(outputs)
74
81
 
75
82
 
76
- class WorkflowExecutionFulfilledEvent(BaseEvent, Generic[OutputsType]):
83
+ class WorkflowExecutionFulfilledEvent(_BaseWorkflowEvent, Generic[OutputsType]):
77
84
  name: Literal["workflow.execution.fulfilled"] = "workflow.execution.fulfilled"
78
85
  body: WorkflowExecutionFulfilledBody[OutputsType]
79
86
 
@@ -86,7 +93,7 @@ class WorkflowExecutionRejectedBody(_BaseWorkflowExecutionBody):
86
93
  error: VellumError
87
94
 
88
95
 
89
- class WorkflowExecutionRejectedEvent(BaseEvent):
96
+ class WorkflowExecutionRejectedEvent(_BaseWorkflowEvent):
90
97
  name: Literal["workflow.execution.rejected"] = "workflow.execution.rejected"
91
98
  body: WorkflowExecutionRejectedBody
92
99
 
@@ -99,7 +106,7 @@ class WorkflowExecutionPausedBody(_BaseWorkflowExecutionBody):
99
106
  external_inputs: Iterable[ExternalInputReference]
100
107
 
101
108
 
102
- class WorkflowExecutionPausedEvent(BaseEvent):
109
+ class WorkflowExecutionPausedEvent(_BaseWorkflowEvent):
103
110
  name: Literal["workflow.execution.paused"] = "workflow.execution.paused"
104
111
  body: WorkflowExecutionPausedBody
105
112
 
@@ -112,11 +119,28 @@ class WorkflowExecutionResumedBody(_BaseWorkflowExecutionBody):
112
119
  pass
113
120
 
114
121
 
115
- class WorkflowExecutionResumedEvent(BaseEvent):
122
+ class WorkflowExecutionResumedEvent(_BaseWorkflowEvent):
116
123
  name: Literal["workflow.execution.resumed"] = "workflow.execution.resumed"
117
124
  body: WorkflowExecutionResumedBody
118
125
 
119
126
 
127
+ class WorkflowExecutionSnapshottedBody(_BaseWorkflowExecutionBody, Generic[StateType]):
128
+ state: StateType
129
+
130
+ @field_serializer("state")
131
+ def serialize_state(self, state: StateType, _info: Any) -> Dict[str, Any]:
132
+ return default_serializer(state)
133
+
134
+
135
+ class WorkflowExecutionSnapshottedEvent(_BaseWorkflowEvent, Generic[StateType]):
136
+ name: Literal["workflow.execution.snapshotted"] = "workflow.execution.snapshotted"
137
+ body: WorkflowExecutionSnapshottedBody[StateType]
138
+
139
+ @property
140
+ def state(self) -> StateType:
141
+ return self.body.state
142
+
143
+
120
144
  GenericWorkflowEvent = Union[
121
145
  WorkflowExecutionStreamingEvent,
122
146
  WorkflowExecutionRejectedEvent,
@@ -134,6 +158,7 @@ WorkflowEvent = Union[
134
158
  GenericWorkflowEvent,
135
159
  WorkflowExecutionInitiatedEvent,
136
160
  WorkflowExecutionFulfilledEvent,
161
+ WorkflowExecutionSnapshottedEvent,
137
162
  ]
138
163
 
139
164
  WorkflowEventStream = Generator[WorkflowEvent, None, None]
@@ -1,12 +1,12 @@
1
1
  from functools import cached_property, reduce
2
2
  import inspect
3
3
  from types import MappingProxyType
4
+ from uuid import UUID
4
5
  from typing import Any, Dict, Generic, Iterator, Optional, Set, Tuple, Type, TypeVar, Union, cast, get_args
5
6
 
6
7
  from vellum.workflows.constants import UNDEF
7
8
  from vellum.workflows.descriptors.base import BaseDescriptor
8
- from vellum.workflows.descriptors.utils import resolve_value
9
- from vellum.workflows.edges.edge import Edge
9
+ from vellum.workflows.descriptors.utils import is_unresolved, resolve_value
10
10
  from vellum.workflows.errors.types import VellumErrorCode
11
11
  from vellum.workflows.exceptions import NodeException
12
12
  from vellum.workflows.graph import Graph
@@ -44,7 +44,11 @@ class BaseNodeMeta(type):
44
44
  if "Outputs" not in dct:
45
45
  for base in reversed(bases):
46
46
  if hasattr(base, "Outputs"):
47
- dct["Outputs"] = type(f"{name}.Outputs", (base.Outputs,), {"__module__": dct["__module__"]})
47
+ dct["Outputs"] = type(
48
+ f"{name}.Outputs",
49
+ (base.Outputs,),
50
+ {"__module__": dct["__module__"]},
51
+ )
48
52
  break
49
53
  else:
50
54
  raise ValueError("Outputs class not found in base classes")
@@ -66,13 +70,20 @@ class BaseNodeMeta(type):
66
70
  if "Execution" not in dct:
67
71
  for base in reversed(bases):
68
72
  if issubclass(base, BaseNode):
69
- dct["Execution"] = type(f"{name}.Execution", (base.Execution,), {"__module__": dct["__module__"]})
73
+ dct["Execution"] = type(
74
+ f"{name}.Execution",
75
+ (base.Execution,),
76
+ {"__module__": dct["__module__"]},
77
+ )
70
78
  break
71
79
 
72
80
  if "Trigger" not in dct:
73
81
  for base in reversed(bases):
74
82
  if issubclass(base, BaseNode):
75
- trigger_dct = {**base.Trigger.__dict__, "__module__": dct["__module__"]}
83
+ trigger_dct = {
84
+ **base.Trigger.__dict__,
85
+ "__module__": dct["__module__"],
86
+ }
76
87
  dct["Trigger"] = type(f"{name}.Trigger", (base.Trigger,), trigger_dct)
77
88
  break
78
89
 
@@ -225,34 +236,40 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
225
236
 
226
237
  @classmethod
227
238
  def should_initiate(
228
- cls, state: StateType, dependencies: Set["Type[BaseNode]"], invoked_by: "Optional[Edge]" = None
239
+ cls,
240
+ state: StateType,
241
+ dependencies: Set["Type[BaseNode]"],
242
+ node_span_id: UUID,
229
243
  ) -> bool:
230
244
  """
231
245
  Determines whether a Node's execution should be initiated. Override this method to define custom
232
246
  trigger criteria.
233
247
  """
234
248
 
235
- if cls.merge_behavior == MergeBehavior.AWAIT_ANY:
236
- if not invoked_by:
237
- return True
238
-
239
- is_ready = not state.meta.node_execution_cache.is_node_initiated(cls.node_class)
249
+ if cls.merge_behavior == MergeBehavior.AWAIT_ATTRIBUTES:
250
+ if state.meta.node_execution_cache.is_node_execution_initiated(cls.node_class, node_span_id):
251
+ return False
240
252
 
241
- invoked_identifier = str(invoked_by.from_port.node_class)
242
- node_identifier = str(cls.node_class)
253
+ is_ready = True
254
+ for descriptor in cls.node_class:
255
+ if not descriptor.instance:
256
+ continue
243
257
 
244
- dependencies_invoked = state.meta.node_execution_cache.dependencies_invoked[node_identifier]
245
- dependencies_invoked.add(invoked_identifier)
246
- if all(str(dep) in dependencies_invoked for dep in dependencies):
247
- del state.meta.node_execution_cache.dependencies_invoked[node_identifier]
258
+ resolved_value = resolve_value(descriptor.instance, state, path=descriptor.name)
259
+ if is_unresolved(resolved_value):
260
+ is_ready = False
261
+ break
248
262
 
249
263
  return is_ready
250
264
 
251
- if cls.merge_behavior == MergeBehavior.AWAIT_ALL:
252
- if not invoked_by:
253
- return True
265
+ if cls.merge_behavior == MergeBehavior.AWAIT_ANY:
266
+ if state.meta.node_execution_cache.is_node_execution_initiated(cls.node_class, node_span_id):
267
+ return False
268
+
269
+ return True
254
270
 
255
- if state.meta.node_execution_cache.is_node_initiated(cls.node_class):
271
+ if cls.merge_behavior == MergeBehavior.AWAIT_ALL:
272
+ if state.meta.node_execution_cache.is_node_execution_initiated(cls.node_class, node_span_id):
256
273
  return False
257
274
 
258
275
  """
@@ -260,20 +277,26 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
260
277
  when all of its dependencies have been executed N times.
261
278
  """
262
279
  current_node_execution_count = state.meta.node_execution_cache.get_execution_count(cls.node_class)
263
- is_ready_outcome = all(
280
+ return all(
264
281
  state.meta.node_execution_cache.get_execution_count(dep) == current_node_execution_count + 1
265
282
  for dep in dependencies
266
283
  )
267
284
 
268
- return is_ready_outcome
269
-
270
- raise NodeException(message="Invalid Trigger Node Specification", code=VellumErrorCode.INVALID_INPUTS)
285
+ raise NodeException(
286
+ message="Invalid Trigger Node Specification",
287
+ code=VellumErrorCode.INVALID_INPUTS,
288
+ )
271
289
 
272
290
  class Execution(metaclass=_BaseNodeExecutionMeta):
273
291
  node_class: Type["BaseNode"]
274
292
  count: int
275
293
 
276
- def __init__(self, *, state: Optional[StateType] = None, context: Optional[WorkflowContext] = None):
294
+ def __init__(
295
+ self,
296
+ *,
297
+ state: Optional[StateType] = None,
298
+ context: Optional[WorkflowContext] = None,
299
+ ):
277
300
  if state:
278
301
  self.state = state
279
302
  else:
@@ -1,7 +1,6 @@
1
1
  from typing import Optional
2
2
 
3
3
  from vellum.core.pydantic_utilities import UniversalBaseModel
4
-
5
4
  from vellum.workflows.inputs.base import BaseInputs
6
5
  from vellum.workflows.nodes.bases.base import BaseNode
7
6
  from vellum.workflows.state.base import BaseState, StateMeta
@@ -87,6 +87,7 @@ class TemplatingNode(BaseNode[StateType], Generic[StateType, _OutputType], metac
87
87
 
88
88
  result: _OutputType - The result of the template rendering
89
89
  """
90
+
90
91
  # We use our mypy plugin to override the _OutputType with the actual output type
91
92
  # for downstream references to this output.
92
93
  result: _OutputType # type: ignore[valid-type]
@@ -1,6 +1,6 @@
1
1
  import sys
2
2
  from types import ModuleType
3
- from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Iterator, Optional, Set, Tuple, Type, TypeVar, cast
3
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Iterator, Optional, Set, Tuple, Type, TypeVar
4
4
 
5
5
  from vellum.workflows.errors.types import VellumError, VellumErrorCode
6
6
  from vellum.workflows.exceptions import NodeException
@@ -8,7 +8,9 @@ from vellum.workflows.nodes.bases import BaseNode
8
8
  from vellum.workflows.nodes.bases.base import BaseNodeMeta
9
9
  from vellum.workflows.nodes.utils import ADORNMENT_MODULE_NAME
10
10
  from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
11
+ from vellum.workflows.state.context import WorkflowContext
11
12
  from vellum.workflows.types.generics import StateType
13
+ from vellum.workflows.workflows.event_filters import all_workflow_event_filter
12
14
 
13
15
  if TYPE_CHECKING:
14
16
  from vellum.workflows import BaseWorkflow
@@ -44,6 +46,14 @@ class _TryNodeMeta(BaseNodeMeta):
44
46
 
45
47
  return node_class
46
48
 
49
+ def __getattribute__(cls, name: str) -> Any:
50
+ try:
51
+ return super().__getattribute__(name)
52
+ except AttributeError:
53
+ if name != "__wrapped_node__" and issubclass(cls, TryNode):
54
+ return getattr(cls.__wrapped_node__, name)
55
+ raise
56
+
47
57
 
48
58
  class TryNode(BaseNode[StateType], Generic[StateType], metaclass=_TryNodeMeta):
49
59
  """
@@ -53,6 +63,7 @@ class TryNode(BaseNode[StateType], Generic[StateType], metaclass=_TryNodeMeta):
53
63
  subworkflow: Type["BaseWorkflow"] - The Subworkflow to execute
54
64
  """
55
65
 
66
+ __wrapped_node__: Optional[Type["BaseNode"]] = None
56
67
  on_error_code: Optional[VellumErrorCode] = None
57
68
  subworkflow: Type["BaseWorkflow"]
58
69
 
@@ -62,15 +73,20 @@ class TryNode(BaseNode[StateType], Generic[StateType], metaclass=_TryNodeMeta):
62
73
  def run(self) -> Iterator[BaseOutput]:
63
74
  subworkflow = self.subworkflow(
64
75
  parent_state=self.state,
65
- context=self._context,
76
+ context=WorkflowContext(
77
+ _vellum_client=self._context._vellum_client,
78
+ ),
79
+ )
80
+ subworkflow_stream = subworkflow.stream(
81
+ event_filter=all_workflow_event_filter,
66
82
  )
67
- subworkflow_stream = subworkflow.stream()
68
83
 
69
84
  outputs: Optional[BaseOutputs] = None
70
85
  exception: Optional[NodeException] = None
71
86
  fulfilled_output_names: Set[str] = set()
72
87
 
73
88
  for event in subworkflow_stream:
89
+ self._context._emit_subworkflow_event(event)
74
90
  if exception:
75
91
  continue
76
92
 
@@ -122,8 +138,9 @@ Message: {event.error.message}""",
122
138
  # https://app.shortcut.com/vellum/story/4116
123
139
  from vellum.workflows import BaseWorkflow
124
140
 
141
+ inner_cls._is_wrapped_node = True
142
+
125
143
  class Subworkflow(BaseWorkflow):
126
- inner_cls._is_wrapped_node = True
127
144
  graph = inner_cls
128
145
 
129
146
  # mypy is wrong here, this works and is defined
@@ -139,6 +156,7 @@ Message: {event.error.message}""",
139
156
  cls.__name__,
140
157
  (TryNode,),
141
158
  {
159
+ "__wrapped_node__": inner_cls,
142
160
  "__module__": dynamic_module,
143
161
  "on_error_code": _on_error_code,
144
162
  "subworkflow": Subworkflow,