vellum-ai 0.10.8__py3-none-any.whl → 0.10.9__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/types/logical_operator.py +2 -0
- vellum/workflows/descriptors/utils.py +27 -0
- vellum/workflows/events/__init__.py +0 -2
- vellum/workflows/events/tests/test_event.py +2 -1
- vellum/workflows/events/types.py +35 -29
- vellum/workflows/events/workflow.py +14 -7
- vellum/workflows/nodes/bases/base.py +100 -38
- vellum/workflows/nodes/core/try_node/node.py +22 -4
- vellum/workflows/nodes/core/try_node/tests/test_node.py +15 -0
- vellum/workflows/runner/runner.py +109 -42
- vellum/workflows/state/base.py +55 -21
- vellum/workflows/state/context.py +26 -3
- vellum/workflows/types/core.py +1 -1
- vellum/workflows/workflows/base.py +51 -17
- vellum/workflows/workflows/event_filters.py +61 -0
- {vellum_ai-0.10.8.dist-info → vellum_ai-0.10.9.dist-info}/METADATA +1 -1
- {vellum_ai-0.10.8.dist-info → vellum_ai-0.10.9.dist-info}/RECORD +24 -22
- vellum_ee/workflows/display/nodes/vellum/__init__.py +6 -4
- vellum_ee/workflows/display/nodes/vellum/error_node.py +49 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +203 -0
- vellum/workflows/events/utils.py +0 -5
- {vellum_ai-0.10.8.dist-info → vellum_ai-0.10.9.dist-info}/LICENSE +0 -0
- {vellum_ai-0.10.8.dist-info → vellum_ai-0.10.9.dist-info}/WHEEL +0 -0
- {vellum_ai-0.10.8.dist-info → vellum_ai-0.10.9.dist-info}/entry_points.txt +0 -0
@@ -17,7 +17,7 @@ class BaseClientWrapper:
|
|
17
17
|
headers: typing.Dict[str, str] = {
|
18
18
|
"X-Fern-Language": "Python",
|
19
19
|
"X-Fern-SDK-Name": "vellum-ai",
|
20
|
-
"X-Fern-SDK-Version": "0.10.
|
20
|
+
"X-Fern-SDK-Version": "0.10.9",
|
21
21
|
}
|
22
22
|
headers["X_API_KEY"] = self.api_key
|
23
23
|
return headers
|
@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Sequence, Set, TypeVar, Union, cast, ove
|
|
5
5
|
|
6
6
|
from pydantic import BaseModel
|
7
7
|
|
8
|
+
from vellum.workflows.constants import UNDEF
|
8
9
|
from vellum.workflows.descriptors.base import BaseDescriptor
|
9
10
|
from vellum.workflows.state.base import BaseState
|
10
11
|
|
@@ -88,3 +89,29 @@ def resolve_value(
|
|
88
89
|
return cast(_T, set_value)
|
89
90
|
|
90
91
|
return value
|
92
|
+
|
93
|
+
|
94
|
+
def is_unresolved(value: Any) -> bool:
|
95
|
+
"""
|
96
|
+
Recursively checks if a value has an unresolved value, represented by UNDEF.
|
97
|
+
"""
|
98
|
+
|
99
|
+
if value is UNDEF:
|
100
|
+
return True
|
101
|
+
|
102
|
+
if dataclasses.is_dataclass(value):
|
103
|
+
return any(is_unresolved(getattr(value, field.name)) for field in dataclasses.fields(value))
|
104
|
+
|
105
|
+
if isinstance(value, BaseModel):
|
106
|
+
return any(is_unresolved(getattr(value, key)) for key in value.model_fields.keys())
|
107
|
+
|
108
|
+
if isinstance(value, Mapping):
|
109
|
+
return any(is_unresolved(item) for item in value.values())
|
110
|
+
|
111
|
+
if isinstance(value, Sequence):
|
112
|
+
return any(is_unresolved(item) for item in value)
|
113
|
+
|
114
|
+
if isinstance(value, Set):
|
115
|
+
return any(is_unresolved(item) for item in value)
|
116
|
+
|
117
|
+
return False
|
@@ -5,7 +5,6 @@ from .node import (
|
|
5
5
|
NodeExecutionRejectedEvent,
|
6
6
|
NodeExecutionStreamingEvent,
|
7
7
|
)
|
8
|
-
from .types import WorkflowEventType
|
9
8
|
from .workflow import (
|
10
9
|
WorkflowEvent,
|
11
10
|
WorkflowEventStream,
|
@@ -27,5 +26,4 @@ __all__ = [
|
|
27
26
|
"WorkflowExecutionStreamingEvent",
|
28
27
|
"WorkflowEvent",
|
29
28
|
"WorkflowEventStream",
|
30
|
-
"WorkflowEventType",
|
31
29
|
]
|
@@ -100,7 +100,8 @@ module_root = name_parts[: name_parts.index("events")]
|
|
100
100
|
node_definition=MockNode,
|
101
101
|
span_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
|
102
102
|
parent=WorkflowParentContext(
|
103
|
-
workflow_definition=MockWorkflow,
|
103
|
+
workflow_definition=MockWorkflow,
|
104
|
+
span_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
|
104
105
|
),
|
105
106
|
),
|
106
107
|
),
|
vellum/workflows/events/types.py
CHANGED
@@ -2,23 +2,14 @@ from datetime import datetime
|
|
2
2
|
from enum import Enum
|
3
3
|
import json
|
4
4
|
from uuid import UUID, uuid4
|
5
|
-
from typing import
|
5
|
+
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
6
6
|
|
7
|
-
from pydantic import
|
7
|
+
from pydantic import BeforeValidator, Field
|
8
8
|
|
9
9
|
from vellum.core.pydantic_utilities import UniversalBaseModel
|
10
10
|
from vellum.workflows.state.encoder import DefaultStateEncoder
|
11
11
|
from vellum.workflows.types.utils import datetime_now
|
12
12
|
|
13
|
-
if TYPE_CHECKING:
|
14
|
-
from vellum.workflows.nodes.bases.base import BaseNode
|
15
|
-
from vellum.workflows.workflows.base import BaseWorkflow
|
16
|
-
|
17
|
-
|
18
|
-
class WorkflowEventType(Enum):
|
19
|
-
NODE = "NODE"
|
20
|
-
WORKFLOW = "WORKFLOW"
|
21
|
-
|
22
13
|
|
23
14
|
def default_datetime_factory() -> datetime:
|
24
15
|
"""
|
@@ -47,9 +38,25 @@ def default_serializer(obj: Any) -> Any:
|
|
47
38
|
)
|
48
39
|
|
49
40
|
|
41
|
+
class CodeResourceDefinition(UniversalBaseModel):
|
42
|
+
name: str
|
43
|
+
module: List[str]
|
44
|
+
|
45
|
+
@staticmethod
|
46
|
+
def encode(obj: type) -> "CodeResourceDefinition":
|
47
|
+
return CodeResourceDefinition(**serialize_type_encoder(obj))
|
48
|
+
|
49
|
+
|
50
|
+
VellumCodeResourceDefinition = Annotated[
|
51
|
+
CodeResourceDefinition,
|
52
|
+
BeforeValidator(lambda d: (d if type(d) is dict else serialize_type_encoder(d))),
|
53
|
+
]
|
54
|
+
|
55
|
+
|
50
56
|
class BaseParentContext(UniversalBaseModel):
|
51
57
|
span_id: UUID
|
52
|
-
parent: Optional[
|
58
|
+
parent: Optional["ParentContext"] = None
|
59
|
+
type: str
|
53
60
|
|
54
61
|
|
55
62
|
class BaseDeploymentParentContext(BaseParentContext):
|
@@ -73,29 +80,28 @@ class PromptDeploymentParentContext(BaseDeploymentParentContext):
|
|
73
80
|
|
74
81
|
class NodeParentContext(BaseParentContext):
|
75
82
|
type: Literal["WORKFLOW_NODE"] = "WORKFLOW_NODE"
|
76
|
-
node_definition:
|
77
|
-
|
78
|
-
@field_serializer("node_definition")
|
79
|
-
def serialize_node_definition(self, definition: Type, _info: Any) -> Dict[str, Any]:
|
80
|
-
return serialize_type_encoder(definition)
|
83
|
+
node_definition: VellumCodeResourceDefinition
|
81
84
|
|
82
85
|
|
83
86
|
class WorkflowParentContext(BaseParentContext):
|
84
87
|
type: Literal["WORKFLOW"] = "WORKFLOW"
|
85
|
-
workflow_definition:
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
88
|
+
workflow_definition: VellumCodeResourceDefinition
|
89
|
+
|
90
|
+
|
91
|
+
# Define the discriminated union
|
92
|
+
ParentContext = Annotated[
|
93
|
+
Union[
|
94
|
+
WorkflowParentContext,
|
95
|
+
NodeParentContext,
|
96
|
+
WorkflowDeploymentParentContext,
|
97
|
+
PromptDeploymentParentContext,
|
98
|
+
],
|
99
|
+
Field(discriminator="type"),
|
97
100
|
]
|
98
101
|
|
102
|
+
# Update the forward references
|
103
|
+
BaseParentContext.model_rebuild()
|
104
|
+
|
99
105
|
|
100
106
|
class BaseEvent(UniversalBaseModel):
|
101
107
|
id: UUID = Field(default_factory=uuid4)
|
@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Generic, Iterable, Liter
|
|
3
3
|
from pydantic import field_serializer
|
4
4
|
|
5
5
|
from vellum.core.pydantic_utilities import UniversalBaseModel
|
6
|
-
|
7
6
|
from vellum.workflows.errors import VellumError
|
8
7
|
from vellum.workflows.outputs.base import BaseOutput
|
9
8
|
from vellum.workflows.references import ExternalInputReference
|
@@ -31,6 +30,14 @@ class _BaseWorkflowExecutionBody(UniversalBaseModel):
|
|
31
30
|
return serialize_type_encoder(workflow_definition)
|
32
31
|
|
33
32
|
|
33
|
+
class _BaseWorkflowEvent(BaseEvent):
|
34
|
+
body: _BaseWorkflowExecutionBody
|
35
|
+
|
36
|
+
@property
|
37
|
+
def workflow_definition(self) -> Type["BaseWorkflow"]:
|
38
|
+
return self.body.workflow_definition
|
39
|
+
|
40
|
+
|
34
41
|
class WorkflowExecutionInitiatedBody(_BaseWorkflowExecutionBody, Generic[WorkflowInputsType]):
|
35
42
|
inputs: WorkflowInputsType
|
36
43
|
|
@@ -39,7 +46,7 @@ class WorkflowExecutionInitiatedBody(_BaseWorkflowExecutionBody, Generic[Workflo
|
|
39
46
|
return default_serializer(inputs)
|
40
47
|
|
41
48
|
|
42
|
-
class WorkflowExecutionInitiatedEvent(
|
49
|
+
class WorkflowExecutionInitiatedEvent(_BaseWorkflowEvent, Generic[WorkflowInputsType]):
|
43
50
|
name: Literal["workflow.execution.initiated"] = "workflow.execution.initiated"
|
44
51
|
body: WorkflowExecutionInitiatedBody[WorkflowInputsType]
|
45
52
|
|
@@ -56,7 +63,7 @@ class WorkflowExecutionStreamingBody(_BaseWorkflowExecutionBody):
|
|
56
63
|
return default_serializer(output)
|
57
64
|
|
58
65
|
|
59
|
-
class WorkflowExecutionStreamingEvent(
|
66
|
+
class WorkflowExecutionStreamingEvent(_BaseWorkflowEvent):
|
60
67
|
name: Literal["workflow.execution.streaming"] = "workflow.execution.streaming"
|
61
68
|
body: WorkflowExecutionStreamingBody
|
62
69
|
|
@@ -73,7 +80,7 @@ class WorkflowExecutionFulfilledBody(_BaseWorkflowExecutionBody, Generic[Outputs
|
|
73
80
|
return default_serializer(outputs)
|
74
81
|
|
75
82
|
|
76
|
-
class WorkflowExecutionFulfilledEvent(
|
83
|
+
class WorkflowExecutionFulfilledEvent(_BaseWorkflowEvent, Generic[OutputsType]):
|
77
84
|
name: Literal["workflow.execution.fulfilled"] = "workflow.execution.fulfilled"
|
78
85
|
body: WorkflowExecutionFulfilledBody[OutputsType]
|
79
86
|
|
@@ -86,7 +93,7 @@ class WorkflowExecutionRejectedBody(_BaseWorkflowExecutionBody):
|
|
86
93
|
error: VellumError
|
87
94
|
|
88
95
|
|
89
|
-
class WorkflowExecutionRejectedEvent(
|
96
|
+
class WorkflowExecutionRejectedEvent(_BaseWorkflowEvent):
|
90
97
|
name: Literal["workflow.execution.rejected"] = "workflow.execution.rejected"
|
91
98
|
body: WorkflowExecutionRejectedBody
|
92
99
|
|
@@ -99,7 +106,7 @@ class WorkflowExecutionPausedBody(_BaseWorkflowExecutionBody):
|
|
99
106
|
external_inputs: Iterable[ExternalInputReference]
|
100
107
|
|
101
108
|
|
102
|
-
class WorkflowExecutionPausedEvent(
|
109
|
+
class WorkflowExecutionPausedEvent(_BaseWorkflowEvent):
|
103
110
|
name: Literal["workflow.execution.paused"] = "workflow.execution.paused"
|
104
111
|
body: WorkflowExecutionPausedBody
|
105
112
|
|
@@ -112,7 +119,7 @@ class WorkflowExecutionResumedBody(_BaseWorkflowExecutionBody):
|
|
112
119
|
pass
|
113
120
|
|
114
121
|
|
115
|
-
class WorkflowExecutionResumedEvent(
|
122
|
+
class WorkflowExecutionResumedEvent(_BaseWorkflowEvent):
|
116
123
|
name: Literal["workflow.execution.resumed"] = "workflow.execution.resumed"
|
117
124
|
body: WorkflowExecutionResumedBody
|
118
125
|
|
@@ -1,13 +1,15 @@
|
|
1
1
|
from functools import cached_property, reduce
|
2
2
|
import inspect
|
3
3
|
from types import MappingProxyType
|
4
|
+
from uuid import UUID
|
4
5
|
from typing import Any, Dict, Generic, Iterator, Optional, Set, Tuple, Type, TypeVar, Union, cast, get_args
|
5
6
|
|
6
7
|
from vellum.workflows.constants import UNDEF
|
7
8
|
from vellum.workflows.descriptors.base import BaseDescriptor
|
8
|
-
from vellum.workflows.descriptors.utils import resolve_value
|
9
|
+
from vellum.workflows.descriptors.utils import is_unresolved, resolve_value
|
9
10
|
from vellum.workflows.edges.edge import Edge
|
10
11
|
from vellum.workflows.errors.types import VellumErrorCode
|
12
|
+
from vellum.workflows.events.types import ParentContext
|
11
13
|
from vellum.workflows.exceptions import NodeException
|
12
14
|
from vellum.workflows.graph import Graph
|
13
15
|
from vellum.workflows.graph.graph import GraphTarget
|
@@ -30,9 +32,15 @@ def is_nested_class(nested: Any, parent: Type) -> bool:
|
|
30
32
|
inspect.isclass(nested)
|
31
33
|
# If a class is defined within a function, we don't consider it nested in the class defining that function
|
32
34
|
# The example of this is a Subworkflow defined within TryNode.wrap()
|
33
|
-
and (
|
35
|
+
and (
|
36
|
+
len(nested.__qualname__.split(".")) < 2
|
37
|
+
or nested.__qualname__.split(".")[-2] != "<locals>"
|
38
|
+
)
|
34
39
|
and nested.__module__ == parent.__module__
|
35
|
-
and (
|
40
|
+
and (
|
41
|
+
nested.__qualname__.startswith(parent.__name__)
|
42
|
+
or nested.__qualname__.startswith(parent.__qualname__)
|
43
|
+
)
|
36
44
|
) or any(is_nested_class(nested, base) for base in parent.__bases__)
|
37
45
|
|
38
46
|
|
@@ -44,7 +52,11 @@ class BaseNodeMeta(type):
|
|
44
52
|
if "Outputs" not in dct:
|
45
53
|
for base in reversed(bases):
|
46
54
|
if hasattr(base, "Outputs"):
|
47
|
-
dct["Outputs"] = type(
|
55
|
+
dct["Outputs"] = type(
|
56
|
+
f"{name}.Outputs",
|
57
|
+
(base.Outputs,),
|
58
|
+
{"__module__": dct["__module__"]},
|
59
|
+
)
|
48
60
|
break
|
49
61
|
else:
|
50
62
|
raise ValueError("Outputs class not found in base classes")
|
@@ -66,14 +78,23 @@ class BaseNodeMeta(type):
|
|
66
78
|
if "Execution" not in dct:
|
67
79
|
for base in reversed(bases):
|
68
80
|
if issubclass(base, BaseNode):
|
69
|
-
dct["Execution"] = type(
|
81
|
+
dct["Execution"] = type(
|
82
|
+
f"{name}.Execution",
|
83
|
+
(base.Execution,),
|
84
|
+
{"__module__": dct["__module__"]},
|
85
|
+
)
|
70
86
|
break
|
71
87
|
|
72
88
|
if "Trigger" not in dct:
|
73
89
|
for base in reversed(bases):
|
74
90
|
if issubclass(base, BaseNode):
|
75
|
-
trigger_dct = {
|
76
|
-
|
91
|
+
trigger_dct = {
|
92
|
+
**base.Trigger.__dict__,
|
93
|
+
"__module__": dct["__module__"],
|
94
|
+
}
|
95
|
+
dct["Trigger"] = type(
|
96
|
+
f"{name}.Trigger", (base.Trigger,), trigger_dct
|
97
|
+
)
|
77
98
|
break
|
78
99
|
|
79
100
|
cls = super().__new__(mcs, name, bases, dct)
|
@@ -118,7 +139,9 @@ class BaseNodeMeta(type):
|
|
118
139
|
|
119
140
|
def __rshift__(cls, other_cls: GraphTarget) -> Graph:
|
120
141
|
if not issubclass(cls, BaseNode):
|
121
|
-
raise ValueError(
|
142
|
+
raise ValueError(
|
143
|
+
"BaseNodeMeta can only be extended from subclasses of BaseNode"
|
144
|
+
)
|
122
145
|
|
123
146
|
if not cls.Ports._default_port:
|
124
147
|
raise ValueError("No default port found on node")
|
@@ -130,7 +153,9 @@ class BaseNodeMeta(type):
|
|
130
153
|
|
131
154
|
def __rrshift__(cls, other_cls: GraphTarget) -> Graph:
|
132
155
|
if not issubclass(cls, BaseNode):
|
133
|
-
raise ValueError(
|
156
|
+
raise ValueError(
|
157
|
+
"BaseNodeMeta can only be extended from subclasses of BaseNode"
|
158
|
+
)
|
134
159
|
|
135
160
|
if not isinstance(other_cls, set):
|
136
161
|
other_cls = {other_cls}
|
@@ -168,13 +193,18 @@ class _BaseNodeTriggerMeta(type):
|
|
168
193
|
if not isinstance(other, _BaseNodeTriggerMeta):
|
169
194
|
return False
|
170
195
|
|
171
|
-
if not self.__name__.endswith(".Trigger") or not other.__name__.endswith(
|
196
|
+
if not self.__name__.endswith(".Trigger") or not other.__name__.endswith(
|
197
|
+
".Trigger"
|
198
|
+
):
|
172
199
|
return super().__eq__(other)
|
173
200
|
|
174
201
|
self_trigger_class = cast(Type["BaseNode.Trigger"], self)
|
175
202
|
other_trigger_class = cast(Type["BaseNode.Trigger"], other)
|
176
203
|
|
177
|
-
return
|
204
|
+
return (
|
205
|
+
self_trigger_class.node_class.__name__
|
206
|
+
== other_trigger_class.node_class.__name__
|
207
|
+
)
|
178
208
|
|
179
209
|
|
180
210
|
class _BaseNodeExecutionMeta(type):
|
@@ -192,13 +222,18 @@ class _BaseNodeExecutionMeta(type):
|
|
192
222
|
if not isinstance(other, _BaseNodeExecutionMeta):
|
193
223
|
return False
|
194
224
|
|
195
|
-
if not self.__name__.endswith(".Execution") or not other.__name__.endswith(
|
225
|
+
if not self.__name__.endswith(".Execution") or not other.__name__.endswith(
|
226
|
+
".Execution"
|
227
|
+
):
|
196
228
|
return super().__eq__(other)
|
197
229
|
|
198
230
|
self_execution_class = cast(Type["BaseNode.Execution"], self)
|
199
231
|
other_execution_class = cast(Type["BaseNode.Execution"], other)
|
200
232
|
|
201
|
-
return
|
233
|
+
return (
|
234
|
+
self_execution_class.node_class.__name__
|
235
|
+
== other_execution_class.node_class.__name__
|
236
|
+
)
|
202
237
|
|
203
238
|
|
204
239
|
class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
|
@@ -225,55 +260,78 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
|
|
225
260
|
|
226
261
|
@classmethod
|
227
262
|
def should_initiate(
|
228
|
-
cls,
|
263
|
+
cls,
|
264
|
+
state: StateType,
|
265
|
+
dependencies: Set["Type[BaseNode]"],
|
266
|
+
node_span_id: UUID,
|
229
267
|
) -> bool:
|
230
268
|
"""
|
231
269
|
Determines whether a Node's execution should be initiated. Override this method to define custom
|
232
270
|
trigger criteria.
|
233
271
|
"""
|
234
272
|
|
235
|
-
if cls.merge_behavior == MergeBehavior.
|
236
|
-
if
|
237
|
-
|
238
|
-
|
239
|
-
|
273
|
+
if cls.merge_behavior == MergeBehavior.AWAIT_ATTRIBUTES:
|
274
|
+
if state.meta.node_execution_cache.is_node_execution_initiated(
|
275
|
+
cls.node_class, node_span_id
|
276
|
+
):
|
277
|
+
return False
|
240
278
|
|
241
|
-
|
242
|
-
|
279
|
+
is_ready = True
|
280
|
+
for descriptor in cls.node_class:
|
281
|
+
if not descriptor.instance:
|
282
|
+
continue
|
243
283
|
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
284
|
+
resolved_value = resolve_value(
|
285
|
+
descriptor.instance, state, path=descriptor.name
|
286
|
+
)
|
287
|
+
if is_unresolved(resolved_value):
|
288
|
+
is_ready = False
|
289
|
+
break
|
248
290
|
|
249
291
|
return is_ready
|
250
292
|
|
251
|
-
if cls.merge_behavior == MergeBehavior.
|
252
|
-
if
|
253
|
-
|
293
|
+
if cls.merge_behavior == MergeBehavior.AWAIT_ANY:
|
294
|
+
if state.meta.node_execution_cache.is_node_execution_initiated(
|
295
|
+
cls.node_class, node_span_id
|
296
|
+
):
|
297
|
+
return False
|
298
|
+
|
299
|
+
return True
|
254
300
|
|
255
|
-
|
301
|
+
if cls.merge_behavior == MergeBehavior.AWAIT_ALL:
|
302
|
+
if state.meta.node_execution_cache.is_node_execution_initiated(
|
303
|
+
cls.node_class, node_span_id
|
304
|
+
):
|
256
305
|
return False
|
257
306
|
|
258
307
|
"""
|
259
308
|
A node utilizing an AWAIT_ALL merge strategy will only be considered ready for the Nth time
|
260
309
|
when all of its dependencies have been executed N times.
|
261
310
|
"""
|
262
|
-
current_node_execution_count =
|
263
|
-
|
264
|
-
|
311
|
+
current_node_execution_count = (
|
312
|
+
state.meta.node_execution_cache.get_execution_count(cls.node_class)
|
313
|
+
)
|
314
|
+
return all(
|
315
|
+
state.meta.node_execution_cache.get_execution_count(dep)
|
316
|
+
== current_node_execution_count + 1
|
265
317
|
for dep in dependencies
|
266
318
|
)
|
267
319
|
|
268
|
-
|
269
|
-
|
270
|
-
|
320
|
+
raise NodeException(
|
321
|
+
message="Invalid Trigger Node Specification",
|
322
|
+
code=VellumErrorCode.INVALID_INPUTS,
|
323
|
+
)
|
271
324
|
|
272
325
|
class Execution(metaclass=_BaseNodeExecutionMeta):
|
273
326
|
node_class: Type["BaseNode"]
|
274
327
|
count: int
|
275
328
|
|
276
|
-
def __init__(
|
329
|
+
def __init__(
|
330
|
+
self,
|
331
|
+
*,
|
332
|
+
state: Optional[StateType] = None,
|
333
|
+
context: Optional[WorkflowContext] = None,
|
334
|
+
):
|
277
335
|
if state:
|
278
336
|
self.state = state
|
279
337
|
else:
|
@@ -295,7 +353,9 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
|
|
295
353
|
if not descriptor.instance:
|
296
354
|
continue
|
297
355
|
|
298
|
-
resolved_value = resolve_value(
|
356
|
+
resolved_value = resolve_value(
|
357
|
+
descriptor.instance, self.state, path=descriptor.name, memo=inputs
|
358
|
+
)
|
299
359
|
setattr(self, descriptor.name, resolved_value)
|
300
360
|
|
301
361
|
# Resolve descriptors set as defaults to the outputs class
|
@@ -319,7 +379,9 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
|
|
319
379
|
for key, value in inputs.items():
|
320
380
|
path_parts = key.split(".")
|
321
381
|
node_attribute_discriptor = getattr(self.__class__, path_parts[0])
|
322
|
-
inputs_key = reduce(
|
382
|
+
inputs_key = reduce(
|
383
|
+
lambda acc, part: acc[part], path_parts[1:], node_attribute_discriptor
|
384
|
+
)
|
323
385
|
all_inputs[inputs_key] = value
|
324
386
|
|
325
387
|
self._inputs = MappingProxyType(all_inputs)
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import sys
|
2
|
-
from types import ModuleType
|
2
|
+
from types import MappingProxyType, ModuleType
|
3
3
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Iterator, Optional, Set, Tuple, Type, TypeVar, cast
|
4
4
|
|
5
5
|
from vellum.workflows.errors.types import VellumError, VellumErrorCode
|
@@ -8,7 +8,9 @@ from vellum.workflows.nodes.bases import BaseNode
|
|
8
8
|
from vellum.workflows.nodes.bases.base import BaseNodeMeta
|
9
9
|
from vellum.workflows.nodes.utils import ADORNMENT_MODULE_NAME
|
10
10
|
from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
|
11
|
+
from vellum.workflows.state.context import WorkflowContext
|
11
12
|
from vellum.workflows.types.generics import StateType
|
13
|
+
from vellum.workflows.workflows.event_filters import all_workflow_event_filter
|
12
14
|
|
13
15
|
if TYPE_CHECKING:
|
14
16
|
from vellum.workflows import BaseWorkflow
|
@@ -44,6 +46,14 @@ class _TryNodeMeta(BaseNodeMeta):
|
|
44
46
|
|
45
47
|
return node_class
|
46
48
|
|
49
|
+
def __getattribute__(cls, name: str) -> Any:
|
50
|
+
try:
|
51
|
+
return super().__getattribute__(name)
|
52
|
+
except AttributeError:
|
53
|
+
if name != "__wrapped_node__" and issubclass(cls, TryNode):
|
54
|
+
return getattr(cls.__wrapped_node__, name)
|
55
|
+
raise
|
56
|
+
|
47
57
|
|
48
58
|
class TryNode(BaseNode[StateType], Generic[StateType], metaclass=_TryNodeMeta):
|
49
59
|
"""
|
@@ -53,6 +63,7 @@ class TryNode(BaseNode[StateType], Generic[StateType], metaclass=_TryNodeMeta):
|
|
53
63
|
subworkflow: Type["BaseWorkflow"] - The Subworkflow to execute
|
54
64
|
"""
|
55
65
|
|
66
|
+
__wrapped_node__: Optional[Type["BaseNode"]] = None
|
56
67
|
on_error_code: Optional[VellumErrorCode] = None
|
57
68
|
subworkflow: Type["BaseWorkflow"]
|
58
69
|
|
@@ -62,15 +73,20 @@ class TryNode(BaseNode[StateType], Generic[StateType], metaclass=_TryNodeMeta):
|
|
62
73
|
def run(self) -> Iterator[BaseOutput]:
|
63
74
|
subworkflow = self.subworkflow(
|
64
75
|
parent_state=self.state,
|
65
|
-
context=
|
76
|
+
context=WorkflowContext(
|
77
|
+
_vellum_client=self._context._vellum_client,
|
78
|
+
),
|
79
|
+
)
|
80
|
+
subworkflow_stream = subworkflow.stream(
|
81
|
+
event_filter=all_workflow_event_filter,
|
66
82
|
)
|
67
|
-
subworkflow_stream = subworkflow.stream()
|
68
83
|
|
69
84
|
outputs: Optional[BaseOutputs] = None
|
70
85
|
exception: Optional[NodeException] = None
|
71
86
|
fulfilled_output_names: Set[str] = set()
|
72
87
|
|
73
88
|
for event in subworkflow_stream:
|
89
|
+
self._context._emit_subworkflow_event(event)
|
74
90
|
if exception:
|
75
91
|
continue
|
76
92
|
|
@@ -122,8 +138,9 @@ Message: {event.error.message}""",
|
|
122
138
|
# https://app.shortcut.com/vellum/story/4116
|
123
139
|
from vellum.workflows import BaseWorkflow
|
124
140
|
|
141
|
+
inner_cls._is_wrapped_node = True
|
142
|
+
|
125
143
|
class Subworkflow(BaseWorkflow):
|
126
|
-
inner_cls._is_wrapped_node = True
|
127
144
|
graph = inner_cls
|
128
145
|
|
129
146
|
# mypy is wrong here, this works and is defined
|
@@ -139,6 +156,7 @@ Message: {event.error.message}""",
|
|
139
156
|
cls.__name__,
|
140
157
|
(TryNode,),
|
141
158
|
{
|
159
|
+
"__wrapped_node__": inner_cls,
|
142
160
|
"__module__": dynamic_module,
|
143
161
|
"on_error_code": _on_error_code,
|
144
162
|
"subworkflow": Subworkflow,
|
@@ -111,3 +111,18 @@ def test_try_node__use_parent_execution_context():
|
|
111
111
|
# THEN the inner node had access to the key
|
112
112
|
assert len(outputs) == 1
|
113
113
|
assert outputs[-1] == BaseOutput(name="key", value="test-key")
|
114
|
+
|
115
|
+
|
116
|
+
def test_try_node__resolved_inputs():
|
117
|
+
"""
|
118
|
+
This test ensures that node attributes of TryNodes are correctly resolved.
|
119
|
+
"""
|
120
|
+
|
121
|
+
class State(BaseState):
|
122
|
+
counter = 3.0
|
123
|
+
|
124
|
+
@TryNode.wrap()
|
125
|
+
class MyNode(BaseNode[State]):
|
126
|
+
foo = State.counter
|
127
|
+
|
128
|
+
assert MyNode.foo.types == (float,)
|