vellum-ai 0.10.4__py3-none-any.whl → 0.10.7__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (76) hide show
  1. vellum/__init__.py +2 -0
  2. vellum/client/README.md +7 -52
  3. vellum/client/__init__.py +16 -136
  4. vellum/client/core/client_wrapper.py +1 -1
  5. vellum/client/resources/ad_hoc/client.py +14 -104
  6. vellum/client/resources/metric_definitions/client.py +113 -0
  7. vellum/client/resources/test_suites/client.py +8 -16
  8. vellum/client/resources/workflows/client.py +0 -32
  9. vellum/client/types/__init__.py +2 -0
  10. vellum/client/types/metric_definition_history_item.py +39 -0
  11. vellum/types/metric_definition_history_item.py +3 -0
  12. vellum/workflows/events/node.py +36 -3
  13. vellum/workflows/events/tests/test_event.py +89 -9
  14. vellum/workflows/nodes/__init__.py +6 -7
  15. vellum/workflows/nodes/bases/base.py +0 -1
  16. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +1 -1
  17. vellum/workflows/nodes/core/templating_node/node.py +5 -1
  18. vellum/workflows/nodes/core/try_node/node.py +65 -27
  19. vellum/workflows/nodes/core/try_node/tests/test_node.py +17 -10
  20. vellum/workflows/nodes/displayable/__init__.py +2 -0
  21. vellum/workflows/nodes/displayable/bases/api_node/node.py +3 -3
  22. vellum/workflows/nodes/displayable/code_execution_node/node.py +5 -2
  23. vellum/workflows/nodes/displayable/conditional_node/node.py +2 -2
  24. vellum/workflows/nodes/displayable/final_output_node/node.py +6 -2
  25. vellum/workflows/nodes/displayable/note_node/__init__.py +5 -0
  26. vellum/workflows/nodes/displayable/note_node/node.py +10 -0
  27. vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py +10 -11
  28. vellum/workflows/nodes/utils.py +2 -0
  29. vellum/workflows/outputs/base.py +26 -2
  30. vellum/workflows/ports/node_ports.py +2 -2
  31. vellum/workflows/ports/port.py +14 -0
  32. vellum/workflows/references/__init__.py +2 -0
  33. vellum/workflows/runner/runner.py +46 -33
  34. vellum/workflows/runner/types.py +1 -3
  35. vellum/workflows/state/encoder.py +2 -1
  36. vellum/workflows/types/tests/test_utils.py +15 -3
  37. vellum/workflows/types/utils.py +4 -1
  38. vellum/workflows/utils/vellum_variables.py +13 -1
  39. vellum/workflows/workflows/base.py +24 -1
  40. {vellum_ai-0.10.4.dist-info → vellum_ai-0.10.7.dist-info}/METADATA +8 -6
  41. {vellum_ai-0.10.4.dist-info → vellum_ai-0.10.7.dist-info}/RECORD +76 -69
  42. vellum_cli/CONTRIBUTING.md +66 -0
  43. vellum_cli/README.md +3 -0
  44. vellum_ee/workflows/display/base.py +2 -1
  45. vellum_ee/workflows/display/nodes/base_node_display.py +27 -4
  46. vellum_ee/workflows/display/nodes/vellum/__init__.py +2 -0
  47. vellum_ee/workflows/display/nodes/vellum/api_node.py +3 -3
  48. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +4 -4
  49. vellum_ee/workflows/display/nodes/vellum/conditional_node.py +86 -41
  50. vellum_ee/workflows/display/nodes/vellum/final_output_node.py +4 -2
  51. vellum_ee/workflows/display/nodes/vellum/guardrail_node.py +3 -3
  52. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +4 -5
  53. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +9 -9
  54. vellum_ee/workflows/display/nodes/vellum/map_node.py +23 -51
  55. vellum_ee/workflows/display/nodes/vellum/note_node.py +32 -0
  56. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +5 -5
  57. vellum_ee/workflows/display/nodes/vellum/search_node.py +1 -1
  58. vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +2 -2
  59. vellum_ee/workflows/display/nodes/vellum/templating_node.py +1 -1
  60. vellum_ee/workflows/display/nodes/vellum/try_node.py +16 -4
  61. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +7 -3
  62. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +122 -107
  63. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +6 -5
  64. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +77 -64
  65. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +15 -11
  66. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +6 -6
  67. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +6 -6
  68. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +4 -3
  69. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +7 -6
  70. vellum_ee/workflows/display/utils/vellum.py +3 -2
  71. vellum_ee/workflows/display/workflows/base_workflow_display.py +14 -9
  72. vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +2 -7
  73. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +18 -16
  74. {vellum_ai-0.10.4.dist-info → vellum_ai-0.10.7.dist-info}/LICENSE +0 -0
  75. {vellum_ai-0.10.4.dist-info → vellum_ai-0.10.7.dist-info}/WHEEL +0 -0
  76. {vellum_ai-0.10.4.dist-info → vellum_ai-0.10.7.dist-info}/entry_points.txt +0 -0
@@ -232,25 +232,21 @@ class TestSuitesClient:
232
232
  api_key="YOUR_API_KEY",
233
233
  )
234
234
  response = client.test_suites.test_suite_test_cases_bulk(
235
- id="string",
235
+ id="id",
236
236
  request=[
237
237
  TestSuiteTestCaseCreateBulkOperationRequest(
238
- id="string",
238
+ id="id",
239
239
  data=CreateTestSuiteTestCaseRequest(
240
- label="string",
241
240
  input_values=[
242
241
  NamedTestCaseStringVariableValueRequest(
243
- value="string",
244
- name="string",
242
+ name="name",
245
243
  )
246
244
  ],
247
245
  evaluation_values=[
248
246
  NamedTestCaseStringVariableValueRequest(
249
- value="string",
250
- name="string",
247
+ name="name",
251
248
  )
252
249
  ],
253
- external_id="string",
254
250
  ),
255
251
  )
256
252
  ],
@@ -571,25 +567,21 @@ class AsyncTestSuitesClient:
571
567
 
572
568
  async def main() -> None:
573
569
  response = await client.test_suites.test_suite_test_cases_bulk(
574
- id="string",
570
+ id="id",
575
571
  request=[
576
572
  TestSuiteTestCaseCreateBulkOperationRequest(
577
- id="string",
573
+ id="id",
578
574
  data=CreateTestSuiteTestCaseRequest(
579
- label="string",
580
575
  input_values=[
581
576
  NamedTestCaseStringVariableValueRequest(
582
- value="string",
583
- name="string",
577
+ name="name",
584
578
  )
585
579
  ],
586
580
  evaluation_values=[
587
581
  NamedTestCaseStringVariableValueRequest(
588
- value="string",
589
- name="string",
582
+ name="name",
590
583
  )
591
584
  ],
592
- external_id="string",
593
585
  ),
594
586
  )
595
587
  ],
@@ -47,18 +47,6 @@ class WorkflowsClient:
47
47
  ------
48
48
  typing.Iterator[bytes]
49
49
 
50
-
51
- Examples
52
- --------
53
- from vellum import Vellum
54
-
55
- client = Vellum(
56
- api_key="YOUR_API_KEY",
57
- )
58
- client.workflows.pull(
59
- id="string",
60
- format="json",
61
- )
62
50
  """
63
51
  with self._client_wrapper.httpx_client.stream(
64
52
  f"v1/workflows/{jsonable_encoder(id)}/pull",
@@ -196,26 +184,6 @@ class AsyncWorkflowsClient:
196
184
  ------
197
185
  typing.AsyncIterator[bytes]
198
186
 
199
-
200
- Examples
201
- --------
202
- import asyncio
203
-
204
- from vellum import AsyncVellum
205
-
206
- client = AsyncVellum(
207
- api_key="YOUR_API_KEY",
208
- )
209
-
210
-
211
- async def main() -> None:
212
- await client.workflows.pull(
213
- id="string",
214
- format="json",
215
- )
216
-
217
-
218
- asyncio.run(main())
219
187
  """
220
188
  async with self._client_wrapper.httpx_client.stream(
221
189
  f"v1/workflows/{jsonable_encoder(id)}/pull",
@@ -197,6 +197,7 @@ from .metadata_filter_rule_combinator import MetadataFilterRuleCombinator
197
197
  from .metadata_filter_rule_request import MetadataFilterRuleRequest
198
198
  from .metadata_filters_request import MetadataFiltersRequest
199
199
  from .metric_definition_execution import MetricDefinitionExecution
200
+ from .metric_definition_history_item import MetricDefinitionHistoryItem
200
201
  from .metric_definition_input import MetricDefinitionInput
201
202
  from .metric_node_result import MetricNodeResult
202
203
  from .ml_model_read import MlModelRead
@@ -685,6 +686,7 @@ __all__ = [
685
686
  "MetadataFilterRuleRequest",
686
687
  "MetadataFiltersRequest",
687
688
  "MetricDefinitionExecution",
689
+ "MetricDefinitionHistoryItem",
688
690
  "MetricDefinitionInput",
689
691
  "MetricNodeResult",
690
692
  "MlModelRead",
@@ -0,0 +1,39 @@
1
+ # This file was auto-generated by Fern from our API Definition.
2
+
3
+ from __future__ import annotations
4
+ from ..core.pydantic_utilities import UniversalBaseModel
5
+ from .array_vellum_value import ArrayVellumValue
6
+ import pydantic
7
+ import typing
8
+ from .vellum_variable import VellumVariable
9
+ from ..core.pydantic_utilities import IS_PYDANTIC_V2
10
+ from ..core.pydantic_utilities import update_forward_refs
11
+
12
+
13
+ class MetricDefinitionHistoryItem(UniversalBaseModel):
14
+ id: str
15
+ label: str = pydantic.Field()
16
+ """
17
+ A human-readable label for the metric
18
+ """
19
+
20
+ name: str = pydantic.Field()
21
+ """
22
+ A name that uniquely identifies this metric within its workspace
23
+ """
24
+
25
+ description: str
26
+ input_variables: typing.List[VellumVariable]
27
+ output_variables: typing.List[VellumVariable]
28
+
29
+ if IS_PYDANTIC_V2:
30
+ model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
31
+ else:
32
+
33
+ class Config:
34
+ frozen = True
35
+ smart_union = True
36
+ extra = pydantic.Extra.allow
37
+
38
+
39
+ update_forward_refs(ArrayVellumValue, MetricDefinitionHistoryItem=MetricDefinitionHistoryItem)
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.metric_definition_history_item import *
@@ -1,13 +1,14 @@
1
- from typing import Any, Dict, Generic, Literal, Type, Union
1
+ from typing import Any, Dict, Generic, Iterable, List, Literal, Optional, Set, Type, Union
2
2
 
3
- from pydantic import field_serializer
3
+ from pydantic import ConfigDict, SerializerFunctionWrapHandler, field_serializer, model_serializer
4
+ from pydantic.main import IncEx
4
5
 
5
6
  from vellum.core.pydantic_utilities import UniversalBaseModel
6
-
7
7
  from vellum.workflows.errors import VellumError
8
8
  from vellum.workflows.expressions.accessor import AccessorExpression
9
9
  from vellum.workflows.nodes.bases import BaseNode
10
10
  from vellum.workflows.outputs.base import BaseOutput
11
+ from vellum.workflows.ports.port import Port
11
12
  from vellum.workflows.references.node import NodeReference
12
13
  from vellum.workflows.types.generics import OutputsType
13
14
 
@@ -21,6 +22,15 @@ class _BaseNodeExecutionBody(UniversalBaseModel):
21
22
  def serialize_node_definition(self, node_definition: Type, _info: Any) -> Dict[str, Any]:
22
23
  return serialize_type_encoder(node_definition)
23
24
 
25
+ # Couldn't get this to work with model_config.exclude_none or model_config.exclude_defaults
26
+ # so we're excluding null invoked_ports manually here for now
27
+ @model_serializer(mode="wrap", when_used="json")
28
+ def serialize_model(self, handler: SerializerFunctionWrapHandler) -> Any:
29
+ serialized = super().serialize_model(handler) # type: ignore[call-arg, arg-type]
30
+ if "invoked_ports" in serialized and serialized["invoked_ports"] is None:
31
+ del serialized["invoked_ports"]
32
+ return serialized
33
+
24
34
 
25
35
  class _BaseNodeEvent(BaseEvent):
26
36
  body: _BaseNodeExecutionBody
@@ -31,6 +41,7 @@ class _BaseNodeEvent(BaseEvent):
31
41
 
32
42
 
33
43
  NodeInputName = Union[NodeReference, AccessorExpression]
44
+ InvokedPorts = Optional[Set["Port"]]
34
45
 
35
46
 
36
47
  class NodeExecutionInitiatedBody(_BaseNodeExecutionBody):
@@ -52,11 +63,18 @@ class NodeExecutionInitiatedEvent(_BaseNodeEvent):
52
63
 
53
64
  class NodeExecutionStreamingBody(_BaseNodeExecutionBody):
54
65
  output: BaseOutput
66
+ invoked_ports: InvokedPorts = None
55
67
 
56
68
  @field_serializer("output")
57
69
  def serialize_output(self, output: BaseOutput, _info: Any) -> Dict[str, Any]:
58
70
  return default_serializer(output)
59
71
 
72
+ @field_serializer("invoked_ports")
73
+ def serialize_invoked_ports(self, invoked_ports: InvokedPorts, _info: Any) -> Optional[List[Dict[str, Any]]]:
74
+ if not invoked_ports:
75
+ return None
76
+ return [default_serializer(port) for port in invoked_ports]
77
+
60
78
 
61
79
  class NodeExecutionStreamingEvent(_BaseNodeEvent):
62
80
  name: Literal["node.execution.streaming"] = "node.execution.streaming"
@@ -66,14 +84,25 @@ class NodeExecutionStreamingEvent(_BaseNodeEvent):
66
84
  def output(self) -> BaseOutput:
67
85
  return self.body.output
68
86
 
87
+ @property
88
+ def invoked_ports(self) -> InvokedPorts:
89
+ return self.body.invoked_ports
90
+
69
91
 
70
92
  class NodeExecutionFulfilledBody(_BaseNodeExecutionBody, Generic[OutputsType]):
71
93
  outputs: OutputsType
94
+ invoked_ports: InvokedPorts = None
72
95
 
73
96
  @field_serializer("outputs")
74
97
  def serialize_outputs(self, outputs: OutputsType, _info: Any) -> Dict[str, Any]:
75
98
  return default_serializer(outputs)
76
99
 
100
+ @field_serializer("invoked_ports")
101
+ def serialize_invoked_ports(self, invoked_ports: InvokedPorts, _info: Any) -> Optional[List[Dict[str, Any]]]:
102
+ if invoked_ports is None:
103
+ return None
104
+ return [default_serializer(port) for port in invoked_ports]
105
+
77
106
 
78
107
  class NodeExecutionFulfilledEvent(_BaseNodeEvent, Generic[OutputsType]):
79
108
  name: Literal["node.execution.fulfilled"] = "node.execution.fulfilled"
@@ -83,6 +112,10 @@ class NodeExecutionFulfilledEvent(_BaseNodeEvent, Generic[OutputsType]):
83
112
  def outputs(self) -> OutputsType:
84
113
  return self.body.outputs
85
114
 
115
+ @property
116
+ def invoked_ports(self) -> InvokedPorts:
117
+ return self.body.invoked_ports
118
+
86
119
 
87
120
  class NodeExecutionRejectedBody(_BaseNodeExecutionBody):
88
121
  error: VellumError
@@ -6,7 +6,14 @@ from uuid import UUID
6
6
  from deepdiff import DeepDiff
7
7
 
8
8
  from vellum.workflows.errors.types import VellumError, VellumErrorCode
9
- from vellum.workflows.events.node import NodeExecutionInitiatedBody, NodeExecutionInitiatedEvent
9
+ from vellum.workflows.events.node import (
10
+ NodeExecutionFulfilledBody,
11
+ NodeExecutionFulfilledEvent,
12
+ NodeExecutionInitiatedBody,
13
+ NodeExecutionInitiatedEvent,
14
+ NodeExecutionStreamingBody,
15
+ NodeExecutionStreamingEvent,
16
+ )
10
17
  from vellum.workflows.events.types import NodeParentContext, WorkflowParentContext
11
18
  from vellum.workflows.events.workflow import (
12
19
  WorkflowExecutionFulfilledBody,
@@ -93,10 +100,9 @@ module_root = name_parts[: name_parts.index("events")]
93
100
  node_definition=MockNode,
94
101
  span_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
95
102
  parent=WorkflowParentContext(
96
- workflow_definition=MockWorkflow,
97
- span_id=UUID("123e4567-e89b-12d3-a456-426614174000")
98
- )
99
- )
103
+ workflow_definition=MockWorkflow, span_id=UUID("123e4567-e89b-12d3-a456-426614174000")
104
+ ),
105
+ ),
100
106
  ),
101
107
  {
102
108
  "id": "123e4567-e89b-12d3-a456-426614174000",
@@ -126,10 +132,10 @@ module_root = name_parts[: name_parts.index("events")]
126
132
  },
127
133
  "type": "WORKFLOW",
128
134
  "parent": None,
129
- "span_id": "123e4567-e89b-12d3-a456-426614174000"
135
+ "span_id": "123e4567-e89b-12d3-a456-426614174000",
130
136
  },
131
137
  "type": "WORKFLOW_NODE",
132
- "span_id": "123e4567-e89b-12d3-a456-426614174000"
138
+ "span_id": "123e4567-e89b-12d3-a456-426614174000",
133
139
  },
134
140
  },
135
141
  ),
@@ -164,7 +170,7 @@ module_root = name_parts[: name_parts.index("events")]
164
170
  "value": "foo",
165
171
  },
166
172
  },
167
- "parent": None
173
+ "parent": None,
168
174
  },
169
175
  ),
170
176
  (
@@ -233,6 +239,78 @@ module_root = name_parts[: name_parts.index("events")]
233
239
  "parent": None,
234
240
  },
235
241
  ),
242
+ (
243
+ NodeExecutionStreamingEvent(
244
+ id=UUID("123e4567-e89b-12d3-a456-426614174000"),
245
+ timestamp=datetime(2024, 1, 1, 12, 0, 0),
246
+ trace_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
247
+ span_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
248
+ body=NodeExecutionStreamingBody(
249
+ node_definition=MockNode,
250
+ output=BaseOutput(
251
+ name="example",
252
+ value="foo",
253
+ ),
254
+ ),
255
+ ),
256
+ {
257
+ "id": "123e4567-e89b-12d3-a456-426614174000",
258
+ "api_version": "2024-10-25",
259
+ "timestamp": "2024-01-01T12:00:00",
260
+ "trace_id": "123e4567-e89b-12d3-a456-426614174000",
261
+ "span_id": "123e4567-e89b-12d3-a456-426614174000",
262
+ "name": "node.execution.streaming",
263
+ "body": {
264
+ "node_definition": {
265
+ "name": "MockNode",
266
+ "module": module_root + ["events", "tests", "test_event"],
267
+ },
268
+ "output": {
269
+ "name": "example",
270
+ "value": "foo",
271
+ },
272
+ },
273
+ "parent": None,
274
+ },
275
+ ),
276
+ (
277
+ NodeExecutionFulfilledEvent(
278
+ id=UUID("123e4567-e89b-12d3-a456-426614174000"),
279
+ timestamp=datetime(2024, 1, 1, 12, 0, 0),
280
+ trace_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
281
+ span_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
282
+ body=NodeExecutionFulfilledBody(
283
+ node_definition=MockNode,
284
+ outputs=MockNode.Outputs(
285
+ example="foo",
286
+ ),
287
+ invoked_ports={MockNode.Ports.default},
288
+ ),
289
+ ),
290
+ {
291
+ "id": "123e4567-e89b-12d3-a456-426614174000",
292
+ "api_version": "2024-10-25",
293
+ "timestamp": "2024-01-01T12:00:00",
294
+ "trace_id": "123e4567-e89b-12d3-a456-426614174000",
295
+ "span_id": "123e4567-e89b-12d3-a456-426614174000",
296
+ "name": "node.execution.fulfilled",
297
+ "body": {
298
+ "node_definition": {
299
+ "name": "MockNode",
300
+ "module": module_root + ["events", "tests", "test_event"],
301
+ },
302
+ "outputs": {
303
+ "example": "foo",
304
+ },
305
+ "invoked_ports": [
306
+ {
307
+ "name": "default",
308
+ }
309
+ ],
310
+ },
311
+ "parent": None,
312
+ },
313
+ ),
236
314
  ],
237
315
  ids=[
238
316
  "workflow.execution.initiated",
@@ -240,7 +318,9 @@ module_root = name_parts[: name_parts.index("events")]
240
318
  "workflow.execution.streaming",
241
319
  "workflow.execution.fulfilled",
242
320
  "workflow.execution.rejected",
321
+ "node.execution.streaming",
322
+ "node.execution.fulfilled",
243
323
  ],
244
324
  )
245
325
  def test_event_serialization(event, expected_json):
246
- assert not DeepDiff(json.loads(event.model_dump_json()), expected_json)
326
+ assert not DeepDiff(event.model_dump(mode="json"), expected_json)
@@ -1,5 +1,5 @@
1
1
  from vellum.workflows.nodes.bases import BaseNode
2
- from vellum.workflows.nodes.core import (ErrorNode, InlineSubworkflowNode, MapNode, RetryNode, TemplatingNode, TryNode,)
2
+ from vellum.workflows.nodes.core import ErrorNode, InlineSubworkflowNode, MapNode, RetryNode, TemplatingNode, TryNode
3
3
  from vellum.workflows.nodes.displayable import (
4
4
  APINode,
5
5
  CodeExecutionNode,
@@ -7,6 +7,7 @@ from vellum.workflows.nodes.displayable import (
7
7
  FinalOutputNode,
8
8
  GuardrailNode,
9
9
  InlinePromptNode,
10
+ NoteNode,
10
11
  PromptDeploymentNode,
11
12
  SearchNode,
12
13
  SubworkflowDeploymentNode,
@@ -28,20 +29,18 @@ __all__ = [
28
29
  "TemplatingNode",
29
30
  "TryNode",
30
31
  # Displayable Base Nodes
31
- "BaseSearchNode",
32
32
  "BaseInlinePromptNode",
33
33
  "BasePromptDeploymentNode",
34
+ "BaseSearchNode",
34
35
  # Displayable Nodes
35
36
  "APINode",
36
37
  "CodeExecutionNode",
38
+ "ConditionalNode",
39
+ "FinalOutputNode",
37
40
  "GuardrailNode",
38
41
  "InlinePromptNode",
42
+ "NoteNode",
39
43
  "PromptDeploymentNode",
40
44
  "SearchNode",
41
- "ConditionalNode",
42
- "GuardrailNode",
43
45
  "SubworkflowDeploymentNode",
44
- "FinalOutputNode",
45
- "PromptDeploymentNode",
46
- "SearchNode",
47
46
  ]
@@ -215,7 +215,6 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
215
215
  # https://app.shortcut.com/vellum/story/4008/auto-inherit-basenodeoutputs-in-outputs-classes
216
216
  class Outputs(BaseOutputs):
217
217
  _node_class: Optional[Type["BaseNode"]] = None
218
- pass
219
218
 
220
219
  class Ports(NodePorts):
221
220
  default = Port(default=True)
@@ -57,7 +57,7 @@ class InlineSubworkflowNode(BaseSubworkflowNode[StateType], Generic[StateType, W
57
57
  if outputs is None:
58
58
  raise NodeException(
59
59
  message="Expected to receive outputs from Workflow Deployment",
60
- code=VellumErrorCode.INTERNAL_ERROR,
60
+ code=VellumErrorCode.INVALID_OUTPUTS,
61
61
  )
62
62
 
63
63
  # For any outputs somehow in our final fulfilled outputs array,
@@ -49,7 +49,11 @@ class _TemplatingNodeMeta(BaseNodeMeta):
49
49
  if not isinstance(parent, _TemplatingNodeMeta):
50
50
  raise ValueError("TemplatingNode must be created with the TemplatingNodeMeta metaclass")
51
51
 
52
- parent.__dict__["Outputs"].__annotations__["result"] = parent.get_output_type()
52
+ annotations = parent.__dict__["Outputs"].__annotations__
53
+ parent.__dict__["Outputs"].__annotations__ = {
54
+ **annotations,
55
+ "result": parent.get_output_type(),
56
+ }
53
57
  return parent
54
58
 
55
59
  def get_output_type(cls) -> Type:
@@ -1,10 +1,13 @@
1
- from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Tuple, Type, TypeVar
1
+ import sys
2
+ from types import ModuleType
3
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Iterator, Optional, Set, Tuple, Type, TypeVar, cast
2
4
 
3
5
  from vellum.workflows.errors.types import VellumError, VellumErrorCode
4
6
  from vellum.workflows.exceptions import NodeException
5
7
  from vellum.workflows.nodes.bases import BaseNode
6
8
  from vellum.workflows.nodes.bases.base import BaseNodeMeta
7
- from vellum.workflows.outputs.base import BaseOutputs
9
+ from vellum.workflows.nodes.utils import ADORNMENT_MODULE_NAME
10
+ from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
8
11
  from vellum.workflows.types.generics import StateType
9
12
 
10
13
  if TYPE_CHECKING:
@@ -56,34 +59,60 @@ class TryNode(BaseNode[StateType], Generic[StateType], metaclass=_TryNodeMeta):
56
59
  class Outputs(BaseNode.Outputs):
57
60
  error: Optional[VellumError] = None
58
61
 
59
- def run(self) -> Outputs:
62
+ def run(self) -> Iterator[BaseOutput]:
60
63
  subworkflow = self.subworkflow(
61
64
  parent_state=self.state,
62
65
  context=self._context,
63
66
  )
64
- terminal_event = subworkflow.run()
65
-
66
- if terminal_event.name == "workflow.execution.fulfilled":
67
- outputs = self.Outputs()
68
- for descriptor, value in terminal_event.outputs:
69
- setattr(outputs, descriptor.name, value)
70
- return outputs
71
- elif terminal_event.name == "workflow.execution.paused":
67
+ subworkflow_stream = subworkflow.stream()
68
+
69
+ outputs: Optional[BaseOutputs] = None
70
+ exception: Optional[NodeException] = None
71
+ fulfilled_output_names: Set[str] = set()
72
+
73
+ for event in subworkflow_stream:
74
+ if exception:
75
+ continue
76
+
77
+ if event.name == "workflow.execution.streaming":
78
+ if event.output.is_fulfilled:
79
+ fulfilled_output_names.add(event.output.name)
80
+ yield event.output
81
+ elif event.name == "workflow.execution.fulfilled":
82
+ outputs = event.outputs
83
+ elif event.name == "workflow.execution.paused":
84
+ exception = NodeException(
85
+ code=VellumErrorCode.INVALID_OUTPUTS,
86
+ message="Subworkflow unexpectedly paused within Try Node",
87
+ )
88
+ elif event.name == "workflow.execution.rejected":
89
+ if self.on_error_code and self.on_error_code != event.error.code:
90
+ exception = NodeException(
91
+ code=VellumErrorCode.INVALID_OUTPUTS,
92
+ message=f"""Unexpected rejection: {event.error.code.value}.
93
+ Message: {event.error.message}""",
94
+ )
95
+ else:
96
+ outputs = self.Outputs(error=event.error)
97
+
98
+ if exception:
99
+ raise exception
100
+
101
+ if outputs is None:
72
102
  raise NodeException(
73
103
  code=VellumErrorCode.INVALID_OUTPUTS,
74
- message="Subworkflow unexpectedly paused within Try Node",
75
- )
76
- elif self.on_error_code and self.on_error_code != terminal_event.error.code:
77
- raise NodeException(
78
- code=VellumErrorCode.INVALID_OUTPUTS,
79
- message=f"""Unexpected rejection: {terminal_event.error.code.value}.
80
- Message: {terminal_event.error.message}""",
81
- )
82
- else:
83
- return self.Outputs(
84
- error=terminal_event.error,
104
+ message="Expected to receive outputs from Try Node's subworkflow",
85
105
  )
86
106
 
107
+ # For any outputs somehow in our final fulfilled outputs array,
108
+ # but not fulfilled by the stream.
109
+ for descriptor, value in outputs:
110
+ if descriptor.name not in fulfilled_output_names:
111
+ yield BaseOutput(
112
+ name=descriptor.name,
113
+ value=value,
114
+ )
115
+
87
116
  @classmethod
88
117
  def wrap(cls, on_error_code: Optional[VellumErrorCode] = None) -> Callable[..., Type["TryNode"]]:
89
118
  _on_error_code = on_error_code
@@ -101,11 +130,20 @@ Message: {terminal_event.error.message}""",
101
130
  class Outputs(inner_cls.Outputs): # type: ignore[name-defined]
102
131
  pass
103
132
 
104
- class WrappedNode(TryNode[StateType]):
105
- on_error_code = _on_error_code
106
-
107
- subworkflow = Subworkflow
108
-
133
+ dynamic_module = f"{inner_cls.__module__}.{inner_cls.__name__}.{ADORNMENT_MODULE_NAME}"
134
+ # This dynamic module allows calls to `type_hints` to work
135
+ sys.modules[dynamic_module] = ModuleType(dynamic_module)
136
+
137
+ # We use a dynamic wrapped node class to be uniquely tied to this `inner_cls` node during serialization
138
+ WrappedNode = type(
139
+ cls.__name__,
140
+ (TryNode,),
141
+ {
142
+ "__module__": dynamic_module,
143
+ "on_error_code": _on_error_code,
144
+ "subworkflow": Subworkflow,
145
+ },
146
+ )
109
147
  return WrappedNode
110
148
 
111
149
  return decorator
@@ -7,6 +7,7 @@ from vellum.workflows.inputs.base import BaseInputs
7
7
  from vellum.workflows.nodes.bases import BaseNode
8
8
  from vellum.workflows.nodes.core.try_node.node import TryNode
9
9
  from vellum.workflows.outputs import BaseOutputs
10
+ from vellum.workflows.outputs.base import BaseOutput
10
11
  from vellum.workflows.state.base import BaseState, StateMeta
11
12
  from vellum.workflows.state.context import WorkflowContext
12
13
 
@@ -23,11 +24,15 @@ def test_try_node__on_error_code__successfully_caught():
23
24
 
24
25
  # WHEN the node is run and throws a PROVIDER_ERROR
25
26
  node = TestNode(state=BaseState())
26
- outputs = node.run()
27
-
28
- # THEN the exception is retried
29
- assert outputs == {
30
- "error": VellumError(message="This will be caught", code=VellumErrorCode.PROVIDER_ERROR),
27
+ outputs = [o for o in node.run()]
28
+
29
+ # THEN the exception is caught and returned
30
+ assert len(outputs) == 2
31
+ assert set(outputs) == {
32
+ BaseOutput(name="value"),
33
+ BaseOutput(
34
+ name="error", value=VellumError(message="This will be caught", code=VellumErrorCode.PROVIDER_ERROR)
35
+ ),
31
36
  }
32
37
 
33
38
 
@@ -44,7 +49,7 @@ def test_try_node__retry_on_error_code__missed():
44
49
  # WHEN the node is run and throws a different exception
45
50
  node = TestNode(state=BaseState())
46
51
  with pytest.raises(NodeException) as exc_info:
47
- node.run()
52
+ list(node.run())
48
53
 
49
54
  # THEN the exception is not caught
50
55
  assert exc_info.value.message == "Unexpected rejection: INTERNAL_ERROR.\nMessage: This will be missed"
@@ -78,10 +83,11 @@ def test_try_node__use_parent_inputs_and_state():
78
83
  meta=StateMeta(workflow_inputs=Inputs(foo="foo")),
79
84
  ),
80
85
  )
81
- outputs = node.run()
86
+ outputs = list(node.run())
82
87
 
83
88
  # THEN the data is used successfully
84
- assert outputs == {"value": "foo bar"}
89
+ assert len(outputs) == 1
90
+ assert outputs[-1] == BaseOutput(name="value", value="foo bar")
85
91
 
86
92
 
87
93
  def test_try_node__use_parent_execution_context():
@@ -100,7 +106,8 @@ def test_try_node__use_parent_execution_context():
100
106
  _vellum_client=Vellum(api_key="test-key"),
101
107
  )
102
108
  )
103
- outputs = node.run()
109
+ outputs = list(node.run())
104
110
 
105
111
  # THEN the inner node had access to the key
106
- assert outputs == {"key": "test-key"}
112
+ assert len(outputs) == 1
113
+ assert outputs[-1] == BaseOutput(name="key", value="test-key")