vellum-ai 0.10.8__py3-none-any.whl → 0.10.9__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -17,7 +17,7 @@ class BaseClientWrapper:
17
17
  headers: typing.Dict[str, str] = {
18
18
  "X-Fern-Language": "Python",
19
19
  "X-Fern-SDK-Name": "vellum-ai",
20
- "X-Fern-SDK-Version": "0.10.8",
20
+ "X-Fern-SDK-Version": "0.10.9",
21
21
  }
22
22
  headers["X_API_KEY"] = self.api_key
23
23
  return headers
@@ -24,6 +24,8 @@ LogicalOperator = typing.Union[
24
24
  "notBetween",
25
25
  "blank",
26
26
  "notBlank",
27
+ "coalesce",
28
+ "accessField",
27
29
  ],
28
30
  typing.Any,
29
31
  ]
@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Sequence, Set, TypeVar, Union, cast, ove
5
5
 
6
6
  from pydantic import BaseModel
7
7
 
8
+ from vellum.workflows.constants import UNDEF
8
9
  from vellum.workflows.descriptors.base import BaseDescriptor
9
10
  from vellum.workflows.state.base import BaseState
10
11
 
@@ -88,3 +89,29 @@ def resolve_value(
88
89
  return cast(_T, set_value)
89
90
 
90
91
  return value
92
+
93
+
94
+ def is_unresolved(value: Any) -> bool:
95
+ """
96
+ Recursively checks if a value has an unresolved value, represented by UNDEF.
97
+ """
98
+
99
+ if value is UNDEF:
100
+ return True
101
+
102
+ if dataclasses.is_dataclass(value):
103
+ return any(is_unresolved(getattr(value, field.name)) for field in dataclasses.fields(value))
104
+
105
+ if isinstance(value, BaseModel):
106
+ return any(is_unresolved(getattr(value, key)) for key in value.model_fields.keys())
107
+
108
+ if isinstance(value, Mapping):
109
+ return any(is_unresolved(item) for item in value.values())
110
+
111
+ if isinstance(value, Sequence):
112
+ return any(is_unresolved(item) for item in value)
113
+
114
+ if isinstance(value, Set):
115
+ return any(is_unresolved(item) for item in value)
116
+
117
+ return False
@@ -5,7 +5,6 @@ from .node import (
5
5
  NodeExecutionRejectedEvent,
6
6
  NodeExecutionStreamingEvent,
7
7
  )
8
- from .types import WorkflowEventType
9
8
  from .workflow import (
10
9
  WorkflowEvent,
11
10
  WorkflowEventStream,
@@ -27,5 +26,4 @@ __all__ = [
27
26
  "WorkflowExecutionStreamingEvent",
28
27
  "WorkflowEvent",
29
28
  "WorkflowEventStream",
30
- "WorkflowEventType",
31
29
  ]
@@ -100,7 +100,8 @@ module_root = name_parts[: name_parts.index("events")]
100
100
  node_definition=MockNode,
101
101
  span_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
102
102
  parent=WorkflowParentContext(
103
- workflow_definition=MockWorkflow, span_id=UUID("123e4567-e89b-12d3-a456-426614174000")
103
+ workflow_definition=MockWorkflow,
104
+ span_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
104
105
  ),
105
106
  ),
106
107
  ),
@@ -2,23 +2,14 @@ from datetime import datetime
2
2
  from enum import Enum
3
3
  import json
4
4
  from uuid import UUID, uuid4
5
- from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Type, Union
5
+ from typing import Annotated, Any, Dict, List, Literal, Optional, Union
6
6
 
7
- from pydantic import Field, field_serializer
7
+ from pydantic import BeforeValidator, Field
8
8
 
9
9
  from vellum.core.pydantic_utilities import UniversalBaseModel
10
10
  from vellum.workflows.state.encoder import DefaultStateEncoder
11
11
  from vellum.workflows.types.utils import datetime_now
12
12
 
13
- if TYPE_CHECKING:
14
- from vellum.workflows.nodes.bases.base import BaseNode
15
- from vellum.workflows.workflows.base import BaseWorkflow
16
-
17
-
18
- class WorkflowEventType(Enum):
19
- NODE = "NODE"
20
- WORKFLOW = "WORKFLOW"
21
-
22
13
 
23
14
  def default_datetime_factory() -> datetime:
24
15
  """
@@ -47,9 +38,25 @@ def default_serializer(obj: Any) -> Any:
47
38
  )
48
39
 
49
40
 
41
+ class CodeResourceDefinition(UniversalBaseModel):
42
+ name: str
43
+ module: List[str]
44
+
45
+ @staticmethod
46
+ def encode(obj: type) -> "CodeResourceDefinition":
47
+ return CodeResourceDefinition(**serialize_type_encoder(obj))
48
+
49
+
50
+ VellumCodeResourceDefinition = Annotated[
51
+ CodeResourceDefinition,
52
+ BeforeValidator(lambda d: (d if type(d) is dict else serialize_type_encoder(d))),
53
+ ]
54
+
55
+
50
56
  class BaseParentContext(UniversalBaseModel):
51
57
  span_id: UUID
52
- parent: Optional['ParentContext'] = None
58
+ parent: Optional["ParentContext"] = None
59
+ type: str
53
60
 
54
61
 
55
62
  class BaseDeploymentParentContext(BaseParentContext):
@@ -73,29 +80,28 @@ class PromptDeploymentParentContext(BaseDeploymentParentContext):
73
80
 
74
81
  class NodeParentContext(BaseParentContext):
75
82
  type: Literal["WORKFLOW_NODE"] = "WORKFLOW_NODE"
76
- node_definition: Type['BaseNode']
77
-
78
- @field_serializer("node_definition")
79
- def serialize_node_definition(self, definition: Type, _info: Any) -> Dict[str, Any]:
80
- return serialize_type_encoder(definition)
83
+ node_definition: VellumCodeResourceDefinition
81
84
 
82
85
 
83
86
  class WorkflowParentContext(BaseParentContext):
84
87
  type: Literal["WORKFLOW"] = "WORKFLOW"
85
- workflow_definition: Type['BaseWorkflow']
86
-
87
- @field_serializer("workflow_definition")
88
- def serialize_workflow_definition(self, definition: Type, _info: Any) -> Dict[str, Any]:
89
- return serialize_type_encoder(definition)
90
-
91
-
92
- ParentContext = Union[
93
- NodeParentContext,
94
- WorkflowParentContext,
95
- PromptDeploymentParentContext,
96
- WorkflowDeploymentParentContext,
88
+ workflow_definition: VellumCodeResourceDefinition
89
+
90
+
91
+ # Define the discriminated union
92
+ ParentContext = Annotated[
93
+ Union[
94
+ WorkflowParentContext,
95
+ NodeParentContext,
96
+ WorkflowDeploymentParentContext,
97
+ PromptDeploymentParentContext,
98
+ ],
99
+ Field(discriminator="type"),
97
100
  ]
98
101
 
102
+ # Update the forward references
103
+ BaseParentContext.model_rebuild()
104
+
99
105
 
100
106
  class BaseEvent(UniversalBaseModel):
101
107
  id: UUID = Field(default_factory=uuid4)
@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Generic, Iterable, Liter
3
3
  from pydantic import field_serializer
4
4
 
5
5
  from vellum.core.pydantic_utilities import UniversalBaseModel
6
-
7
6
  from vellum.workflows.errors import VellumError
8
7
  from vellum.workflows.outputs.base import BaseOutput
9
8
  from vellum.workflows.references import ExternalInputReference
@@ -31,6 +30,14 @@ class _BaseWorkflowExecutionBody(UniversalBaseModel):
31
30
  return serialize_type_encoder(workflow_definition)
32
31
 
33
32
 
33
+ class _BaseWorkflowEvent(BaseEvent):
34
+ body: _BaseWorkflowExecutionBody
35
+
36
+ @property
37
+ def workflow_definition(self) -> Type["BaseWorkflow"]:
38
+ return self.body.workflow_definition
39
+
40
+
34
41
  class WorkflowExecutionInitiatedBody(_BaseWorkflowExecutionBody, Generic[WorkflowInputsType]):
35
42
  inputs: WorkflowInputsType
36
43
 
@@ -39,7 +46,7 @@ class WorkflowExecutionInitiatedBody(_BaseWorkflowExecutionBody, Generic[Workflo
39
46
  return default_serializer(inputs)
40
47
 
41
48
 
42
- class WorkflowExecutionInitiatedEvent(BaseEvent, Generic[WorkflowInputsType]):
49
+ class WorkflowExecutionInitiatedEvent(_BaseWorkflowEvent, Generic[WorkflowInputsType]):
43
50
  name: Literal["workflow.execution.initiated"] = "workflow.execution.initiated"
44
51
  body: WorkflowExecutionInitiatedBody[WorkflowInputsType]
45
52
 
@@ -56,7 +63,7 @@ class WorkflowExecutionStreamingBody(_BaseWorkflowExecutionBody):
56
63
  return default_serializer(output)
57
64
 
58
65
 
59
- class WorkflowExecutionStreamingEvent(BaseEvent):
66
+ class WorkflowExecutionStreamingEvent(_BaseWorkflowEvent):
60
67
  name: Literal["workflow.execution.streaming"] = "workflow.execution.streaming"
61
68
  body: WorkflowExecutionStreamingBody
62
69
 
@@ -73,7 +80,7 @@ class WorkflowExecutionFulfilledBody(_BaseWorkflowExecutionBody, Generic[Outputs
73
80
  return default_serializer(outputs)
74
81
 
75
82
 
76
- class WorkflowExecutionFulfilledEvent(BaseEvent, Generic[OutputsType]):
83
+ class WorkflowExecutionFulfilledEvent(_BaseWorkflowEvent, Generic[OutputsType]):
77
84
  name: Literal["workflow.execution.fulfilled"] = "workflow.execution.fulfilled"
78
85
  body: WorkflowExecutionFulfilledBody[OutputsType]
79
86
 
@@ -86,7 +93,7 @@ class WorkflowExecutionRejectedBody(_BaseWorkflowExecutionBody):
86
93
  error: VellumError
87
94
 
88
95
 
89
- class WorkflowExecutionRejectedEvent(BaseEvent):
96
+ class WorkflowExecutionRejectedEvent(_BaseWorkflowEvent):
90
97
  name: Literal["workflow.execution.rejected"] = "workflow.execution.rejected"
91
98
  body: WorkflowExecutionRejectedBody
92
99
 
@@ -99,7 +106,7 @@ class WorkflowExecutionPausedBody(_BaseWorkflowExecutionBody):
99
106
  external_inputs: Iterable[ExternalInputReference]
100
107
 
101
108
 
102
- class WorkflowExecutionPausedEvent(BaseEvent):
109
+ class WorkflowExecutionPausedEvent(_BaseWorkflowEvent):
103
110
  name: Literal["workflow.execution.paused"] = "workflow.execution.paused"
104
111
  body: WorkflowExecutionPausedBody
105
112
 
@@ -112,7 +119,7 @@ class WorkflowExecutionResumedBody(_BaseWorkflowExecutionBody):
112
119
  pass
113
120
 
114
121
 
115
- class WorkflowExecutionResumedEvent(BaseEvent):
122
+ class WorkflowExecutionResumedEvent(_BaseWorkflowEvent):
116
123
  name: Literal["workflow.execution.resumed"] = "workflow.execution.resumed"
117
124
  body: WorkflowExecutionResumedBody
118
125
 
@@ -1,13 +1,15 @@
1
1
  from functools import cached_property, reduce
2
2
  import inspect
3
3
  from types import MappingProxyType
4
+ from uuid import UUID
4
5
  from typing import Any, Dict, Generic, Iterator, Optional, Set, Tuple, Type, TypeVar, Union, cast, get_args
5
6
 
6
7
  from vellum.workflows.constants import UNDEF
7
8
  from vellum.workflows.descriptors.base import BaseDescriptor
8
- from vellum.workflows.descriptors.utils import resolve_value
9
+ from vellum.workflows.descriptors.utils import is_unresolved, resolve_value
9
10
  from vellum.workflows.edges.edge import Edge
10
11
  from vellum.workflows.errors.types import VellumErrorCode
12
+ from vellum.workflows.events.types import ParentContext
11
13
  from vellum.workflows.exceptions import NodeException
12
14
  from vellum.workflows.graph import Graph
13
15
  from vellum.workflows.graph.graph import GraphTarget
@@ -30,9 +32,15 @@ def is_nested_class(nested: Any, parent: Type) -> bool:
30
32
  inspect.isclass(nested)
31
33
  # If a class is defined within a function, we don't consider it nested in the class defining that function
32
34
  # The example of this is a Subworkflow defined within TryNode.wrap()
33
- and (len(nested.__qualname__.split(".")) < 2 or nested.__qualname__.split(".")[-2] != "<locals>")
35
+ and (
36
+ len(nested.__qualname__.split(".")) < 2
37
+ or nested.__qualname__.split(".")[-2] != "<locals>"
38
+ )
34
39
  and nested.__module__ == parent.__module__
35
- and (nested.__qualname__.startswith(parent.__name__) or nested.__qualname__.startswith(parent.__qualname__))
40
+ and (
41
+ nested.__qualname__.startswith(parent.__name__)
42
+ or nested.__qualname__.startswith(parent.__qualname__)
43
+ )
36
44
  ) or any(is_nested_class(nested, base) for base in parent.__bases__)
37
45
 
38
46
 
@@ -44,7 +52,11 @@ class BaseNodeMeta(type):
44
52
  if "Outputs" not in dct:
45
53
  for base in reversed(bases):
46
54
  if hasattr(base, "Outputs"):
47
- dct["Outputs"] = type(f"{name}.Outputs", (base.Outputs,), {"__module__": dct["__module__"]})
55
+ dct["Outputs"] = type(
56
+ f"{name}.Outputs",
57
+ (base.Outputs,),
58
+ {"__module__": dct["__module__"]},
59
+ )
48
60
  break
49
61
  else:
50
62
  raise ValueError("Outputs class not found in base classes")
@@ -66,14 +78,23 @@ class BaseNodeMeta(type):
66
78
  if "Execution" not in dct:
67
79
  for base in reversed(bases):
68
80
  if issubclass(base, BaseNode):
69
- dct["Execution"] = type(f"{name}.Execution", (base.Execution,), {"__module__": dct["__module__"]})
81
+ dct["Execution"] = type(
82
+ f"{name}.Execution",
83
+ (base.Execution,),
84
+ {"__module__": dct["__module__"]},
85
+ )
70
86
  break
71
87
 
72
88
  if "Trigger" not in dct:
73
89
  for base in reversed(bases):
74
90
  if issubclass(base, BaseNode):
75
- trigger_dct = {**base.Trigger.__dict__, "__module__": dct["__module__"]}
76
- dct["Trigger"] = type(f"{name}.Trigger", (base.Trigger,), trigger_dct)
91
+ trigger_dct = {
92
+ **base.Trigger.__dict__,
93
+ "__module__": dct["__module__"],
94
+ }
95
+ dct["Trigger"] = type(
96
+ f"{name}.Trigger", (base.Trigger,), trigger_dct
97
+ )
77
98
  break
78
99
 
79
100
  cls = super().__new__(mcs, name, bases, dct)
@@ -118,7 +139,9 @@ class BaseNodeMeta(type):
118
139
 
119
140
  def __rshift__(cls, other_cls: GraphTarget) -> Graph:
120
141
  if not issubclass(cls, BaseNode):
121
- raise ValueError("BaseNodeMeta can only be extended from subclasses of BaseNode")
142
+ raise ValueError(
143
+ "BaseNodeMeta can only be extended from subclasses of BaseNode"
144
+ )
122
145
 
123
146
  if not cls.Ports._default_port:
124
147
  raise ValueError("No default port found on node")
@@ -130,7 +153,9 @@ class BaseNodeMeta(type):
130
153
 
131
154
  def __rrshift__(cls, other_cls: GraphTarget) -> Graph:
132
155
  if not issubclass(cls, BaseNode):
133
- raise ValueError("BaseNodeMeta can only be extended from subclasses of BaseNode")
156
+ raise ValueError(
157
+ "BaseNodeMeta can only be extended from subclasses of BaseNode"
158
+ )
134
159
 
135
160
  if not isinstance(other_cls, set):
136
161
  other_cls = {other_cls}
@@ -168,13 +193,18 @@ class _BaseNodeTriggerMeta(type):
168
193
  if not isinstance(other, _BaseNodeTriggerMeta):
169
194
  return False
170
195
 
171
- if not self.__name__.endswith(".Trigger") or not other.__name__.endswith(".Trigger"):
196
+ if not self.__name__.endswith(".Trigger") or not other.__name__.endswith(
197
+ ".Trigger"
198
+ ):
172
199
  return super().__eq__(other)
173
200
 
174
201
  self_trigger_class = cast(Type["BaseNode.Trigger"], self)
175
202
  other_trigger_class = cast(Type["BaseNode.Trigger"], other)
176
203
 
177
- return self_trigger_class.node_class.__name__ == other_trigger_class.node_class.__name__
204
+ return (
205
+ self_trigger_class.node_class.__name__
206
+ == other_trigger_class.node_class.__name__
207
+ )
178
208
 
179
209
 
180
210
  class _BaseNodeExecutionMeta(type):
@@ -192,13 +222,18 @@ class _BaseNodeExecutionMeta(type):
192
222
  if not isinstance(other, _BaseNodeExecutionMeta):
193
223
  return False
194
224
 
195
- if not self.__name__.endswith(".Execution") or not other.__name__.endswith(".Execution"):
225
+ if not self.__name__.endswith(".Execution") or not other.__name__.endswith(
226
+ ".Execution"
227
+ ):
196
228
  return super().__eq__(other)
197
229
 
198
230
  self_execution_class = cast(Type["BaseNode.Execution"], self)
199
231
  other_execution_class = cast(Type["BaseNode.Execution"], other)
200
232
 
201
- return self_execution_class.node_class.__name__ == other_execution_class.node_class.__name__
233
+ return (
234
+ self_execution_class.node_class.__name__
235
+ == other_execution_class.node_class.__name__
236
+ )
202
237
 
203
238
 
204
239
  class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
@@ -225,55 +260,78 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
225
260
 
226
261
  @classmethod
227
262
  def should_initiate(
228
- cls, state: StateType, dependencies: Set["Type[BaseNode]"], invoked_by: "Optional[Edge]" = None
263
+ cls,
264
+ state: StateType,
265
+ dependencies: Set["Type[BaseNode]"],
266
+ node_span_id: UUID,
229
267
  ) -> bool:
230
268
  """
231
269
  Determines whether a Node's execution should be initiated. Override this method to define custom
232
270
  trigger criteria.
233
271
  """
234
272
 
235
- if cls.merge_behavior == MergeBehavior.AWAIT_ANY:
236
- if not invoked_by:
237
- return True
238
-
239
- is_ready = not state.meta.node_execution_cache.is_node_initiated(cls.node_class)
273
+ if cls.merge_behavior == MergeBehavior.AWAIT_ATTRIBUTES:
274
+ if state.meta.node_execution_cache.is_node_execution_initiated(
275
+ cls.node_class, node_span_id
276
+ ):
277
+ return False
240
278
 
241
- invoked_identifier = str(invoked_by.from_port.node_class)
242
- node_identifier = str(cls.node_class)
279
+ is_ready = True
280
+ for descriptor in cls.node_class:
281
+ if not descriptor.instance:
282
+ continue
243
283
 
244
- dependencies_invoked = state.meta.node_execution_cache.dependencies_invoked[node_identifier]
245
- dependencies_invoked.add(invoked_identifier)
246
- if all(str(dep) in dependencies_invoked for dep in dependencies):
247
- del state.meta.node_execution_cache.dependencies_invoked[node_identifier]
284
+ resolved_value = resolve_value(
285
+ descriptor.instance, state, path=descriptor.name
286
+ )
287
+ if is_unresolved(resolved_value):
288
+ is_ready = False
289
+ break
248
290
 
249
291
  return is_ready
250
292
 
251
- if cls.merge_behavior == MergeBehavior.AWAIT_ALL:
252
- if not invoked_by:
253
- return True
293
+ if cls.merge_behavior == MergeBehavior.AWAIT_ANY:
294
+ if state.meta.node_execution_cache.is_node_execution_initiated(
295
+ cls.node_class, node_span_id
296
+ ):
297
+ return False
298
+
299
+ return True
254
300
 
255
- if state.meta.node_execution_cache.is_node_initiated(cls.node_class):
301
+ if cls.merge_behavior == MergeBehavior.AWAIT_ALL:
302
+ if state.meta.node_execution_cache.is_node_execution_initiated(
303
+ cls.node_class, node_span_id
304
+ ):
256
305
  return False
257
306
 
258
307
  """
259
308
  A node utilizing an AWAIT_ALL merge strategy will only be considered ready for the Nth time
260
309
  when all of its dependencies have been executed N times.
261
310
  """
262
- current_node_execution_count = state.meta.node_execution_cache.get_execution_count(cls.node_class)
263
- is_ready_outcome = all(
264
- state.meta.node_execution_cache.get_execution_count(dep) == current_node_execution_count + 1
311
+ current_node_execution_count = (
312
+ state.meta.node_execution_cache.get_execution_count(cls.node_class)
313
+ )
314
+ return all(
315
+ state.meta.node_execution_cache.get_execution_count(dep)
316
+ == current_node_execution_count + 1
265
317
  for dep in dependencies
266
318
  )
267
319
 
268
- return is_ready_outcome
269
-
270
- raise NodeException(message="Invalid Trigger Node Specification", code=VellumErrorCode.INVALID_INPUTS)
320
+ raise NodeException(
321
+ message="Invalid Trigger Node Specification",
322
+ code=VellumErrorCode.INVALID_INPUTS,
323
+ )
271
324
 
272
325
  class Execution(metaclass=_BaseNodeExecutionMeta):
273
326
  node_class: Type["BaseNode"]
274
327
  count: int
275
328
 
276
- def __init__(self, *, state: Optional[StateType] = None, context: Optional[WorkflowContext] = None):
329
+ def __init__(
330
+ self,
331
+ *,
332
+ state: Optional[StateType] = None,
333
+ context: Optional[WorkflowContext] = None,
334
+ ):
277
335
  if state:
278
336
  self.state = state
279
337
  else:
@@ -295,7 +353,9 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
295
353
  if not descriptor.instance:
296
354
  continue
297
355
 
298
- resolved_value = resolve_value(descriptor.instance, self.state, path=descriptor.name, memo=inputs)
356
+ resolved_value = resolve_value(
357
+ descriptor.instance, self.state, path=descriptor.name, memo=inputs
358
+ )
299
359
  setattr(self, descriptor.name, resolved_value)
300
360
 
301
361
  # Resolve descriptors set as defaults to the outputs class
@@ -319,7 +379,9 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
319
379
  for key, value in inputs.items():
320
380
  path_parts = key.split(".")
321
381
  node_attribute_discriptor = getattr(self.__class__, path_parts[0])
322
- inputs_key = reduce(lambda acc, part: acc[part], path_parts[1:], node_attribute_discriptor)
382
+ inputs_key = reduce(
383
+ lambda acc, part: acc[part], path_parts[1:], node_attribute_discriptor
384
+ )
323
385
  all_inputs[inputs_key] = value
324
386
 
325
387
  self._inputs = MappingProxyType(all_inputs)
@@ -1,5 +1,5 @@
1
1
  import sys
2
- from types import ModuleType
2
+ from types import MappingProxyType, ModuleType
3
3
  from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Iterator, Optional, Set, Tuple, Type, TypeVar, cast
4
4
 
5
5
  from vellum.workflows.errors.types import VellumError, VellumErrorCode
@@ -8,7 +8,9 @@ from vellum.workflows.nodes.bases import BaseNode
8
8
  from vellum.workflows.nodes.bases.base import BaseNodeMeta
9
9
  from vellum.workflows.nodes.utils import ADORNMENT_MODULE_NAME
10
10
  from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
11
+ from vellum.workflows.state.context import WorkflowContext
11
12
  from vellum.workflows.types.generics import StateType
13
+ from vellum.workflows.workflows.event_filters import all_workflow_event_filter
12
14
 
13
15
  if TYPE_CHECKING:
14
16
  from vellum.workflows import BaseWorkflow
@@ -44,6 +46,14 @@ class _TryNodeMeta(BaseNodeMeta):
44
46
 
45
47
  return node_class
46
48
 
49
+ def __getattribute__(cls, name: str) -> Any:
50
+ try:
51
+ return super().__getattribute__(name)
52
+ except AttributeError:
53
+ if name != "__wrapped_node__" and issubclass(cls, TryNode):
54
+ return getattr(cls.__wrapped_node__, name)
55
+ raise
56
+
47
57
 
48
58
  class TryNode(BaseNode[StateType], Generic[StateType], metaclass=_TryNodeMeta):
49
59
  """
@@ -53,6 +63,7 @@ class TryNode(BaseNode[StateType], Generic[StateType], metaclass=_TryNodeMeta):
53
63
  subworkflow: Type["BaseWorkflow"] - The Subworkflow to execute
54
64
  """
55
65
 
66
+ __wrapped_node__: Optional[Type["BaseNode"]] = None
56
67
  on_error_code: Optional[VellumErrorCode] = None
57
68
  subworkflow: Type["BaseWorkflow"]
58
69
 
@@ -62,15 +73,20 @@ class TryNode(BaseNode[StateType], Generic[StateType], metaclass=_TryNodeMeta):
62
73
  def run(self) -> Iterator[BaseOutput]:
63
74
  subworkflow = self.subworkflow(
64
75
  parent_state=self.state,
65
- context=self._context,
76
+ context=WorkflowContext(
77
+ _vellum_client=self._context._vellum_client,
78
+ ),
79
+ )
80
+ subworkflow_stream = subworkflow.stream(
81
+ event_filter=all_workflow_event_filter,
66
82
  )
67
- subworkflow_stream = subworkflow.stream()
68
83
 
69
84
  outputs: Optional[BaseOutputs] = None
70
85
  exception: Optional[NodeException] = None
71
86
  fulfilled_output_names: Set[str] = set()
72
87
 
73
88
  for event in subworkflow_stream:
89
+ self._context._emit_subworkflow_event(event)
74
90
  if exception:
75
91
  continue
76
92
 
@@ -122,8 +138,9 @@ Message: {event.error.message}""",
122
138
  # https://app.shortcut.com/vellum/story/4116
123
139
  from vellum.workflows import BaseWorkflow
124
140
 
141
+ inner_cls._is_wrapped_node = True
142
+
125
143
  class Subworkflow(BaseWorkflow):
126
- inner_cls._is_wrapped_node = True
127
144
  graph = inner_cls
128
145
 
129
146
  # mypy is wrong here, this works and is defined
@@ -139,6 +156,7 @@ Message: {event.error.message}""",
139
156
  cls.__name__,
140
157
  (TryNode,),
141
158
  {
159
+ "__wrapped_node__": inner_cls,
142
160
  "__module__": dynamic_module,
143
161
  "on_error_code": _on_error_code,
144
162
  "subworkflow": Subworkflow,
@@ -111,3 +111,18 @@ def test_try_node__use_parent_execution_context():
111
111
  # THEN the inner node had access to the key
112
112
  assert len(outputs) == 1
113
113
  assert outputs[-1] == BaseOutput(name="key", value="test-key")
114
+
115
+
116
+ def test_try_node__resolved_inputs():
117
+ """
118
+ This test ensures that node attributes of TryNodes are correctly resolved.
119
+ """
120
+
121
+ class State(BaseState):
122
+ counter = 3.0
123
+
124
+ @TryNode.wrap()
125
+ class MyNode(BaseNode[State]):
126
+ foo = State.counter
127
+
128
+ assert MyNode.foo.types == (float,)