vellum-ai 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. vellum/client/core/client_wrapper.py +2 -2
  2. vellum/client/types/integration_name.py +1 -0
  3. vellum/workflows/expressions/concat.py +6 -3
  4. vellum/workflows/expressions/tests/test_concat.py +63 -8
  5. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +20 -5
  6. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +11 -7
  7. vellum/workflows/nodes/displayable/bases/inline_prompt_node/tests/test_inline_prompt_node.py +42 -0
  8. vellum/workflows/nodes/displayable/final_output_node/node.py +7 -1
  9. vellum/workflows/nodes/displayable/final_output_node/tests/test_node.py +28 -0
  10. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +84 -56
  11. vellum/workflows/nodes/experimental/__init__.py +1 -3
  12. vellum/workflows/runner/runner.py +144 -0
  13. vellum/workflows/state/context.py +59 -7
  14. vellum/workflows/workflows/base.py +17 -0
  15. vellum/workflows/workflows/event_filters.py +13 -0
  16. vellum/workflows/workflows/tests/test_event_filters.py +126 -0
  17. {vellum_ai-1.8.1.dist-info → vellum_ai-1.8.3.dist-info}/METADATA +1 -1
  18. {vellum_ai-1.8.1.dist-info → vellum_ai-1.8.3.dist-info}/RECORD +23 -23
  19. vellum_ee/workflows/display/utils/expressions.py +4 -0
  20. vellum_ee/workflows/display/utils/tests/test_expressions.py +86 -0
  21. vellum/workflows/nodes/experimental/openai_chat_completion_node/__init__.py +0 -5
  22. vellum/workflows/nodes/experimental/openai_chat_completion_node/node.py +0 -266
  23. {vellum_ai-1.8.1.dist-info → vellum_ai-1.8.3.dist-info}/LICENSE +0 -0
  24. {vellum_ai-1.8.1.dist-info → vellum_ai-1.8.3.dist-info}/WHEEL +0 -0
  25. {vellum_ai-1.8.1.dist-info → vellum_ai-1.8.3.dist-info}/entry_points.txt +0 -0
@@ -27,10 +27,10 @@ class BaseClientWrapper:
27
27
 
28
28
  def get_headers(self) -> typing.Dict[str, str]:
29
29
  headers: typing.Dict[str, str] = {
30
- "User-Agent": "vellum-ai/1.8.1",
30
+ "User-Agent": "vellum-ai/1.8.3",
31
31
  "X-Fern-Language": "Python",
32
32
  "X-Fern-SDK-Name": "vellum-ai",
33
- "X-Fern-SDK-Version": "1.8.1",
33
+ "X-Fern-SDK-Version": "1.8.3",
34
34
  **(self.get_custom_headers() or {}),
35
35
  }
36
36
  if self._api_version is not None:
@@ -41,6 +41,7 @@ IntegrationName = typing.Union[
41
41
  "JIRA",
42
42
  "KLAVIYO",
43
43
  "PAGERDUTY",
44
+ "PARSERA",
44
45
  "PEOPLEDATALABS",
45
46
  "PERPLEXITY",
46
47
  "POSTHOG",
@@ -26,7 +26,10 @@ class ConcatExpression(BaseDescriptor[list], Generic[LHS, RHS]):
26
26
 
27
27
  if not isinstance(lval, list):
28
28
  raise InvalidExpressionException(f"Expected LHS to be a list, got {type(lval)}")
29
- if not isinstance(rval, list):
30
- raise InvalidExpressionException(f"Expected RHS to be a list, got {type(rval)}")
31
29
 
32
- return lval + rval
30
+ # If RHS is a list, concatenate normally
31
+ if isinstance(rval, list):
32
+ return lval + rval
33
+ # If RHS is not a list, append it as a single item
34
+ else:
35
+ return lval + [rval]
@@ -38,16 +38,71 @@ def test_concat_expression_lhs_fail():
38
38
  assert "Expected LHS to be a list, got <class 'int'>" in str(exc_info.value)
39
39
 
40
40
 
41
- def test_concat_expression_rhs_fail():
42
- # GIVEN a list lhs and a non-list rhs
41
+ def test_concat_expression_with_single_item():
42
+ # GIVEN a list and a single item
43
43
  state = TestState()
44
44
  lhs_ref = ConstantValueReference([1, 2, 3])
45
- rhs_ref = ConstantValueReference(False)
45
+ rhs_ref = ConstantValueReference(4)
46
46
  concat_expr = lhs_ref.concat(rhs_ref)
47
47
 
48
- # WHEN we attempt to resolve the expression
49
- with pytest.raises(InvalidExpressionException) as exc_info:
50
- concat_expr.resolve(state)
48
+ # WHEN we resolve the expression
49
+ result = concat_expr.resolve(state)
51
50
 
52
- # THEN an exception should be raised
53
- assert "Expected RHS to be a list, got <class 'bool'>" in str(exc_info.value)
51
+ # THEN the item should be appended to the list
52
+ assert result == [1, 2, 3, 4]
53
+
54
+
55
+ def test_concat_expression_with_string():
56
+ # GIVEN a list and a string item
57
+ state = TestState()
58
+ lhs_ref = ConstantValueReference(["hello", "world"])
59
+ rhs_ref = ConstantValueReference("!")
60
+ concat_expr = lhs_ref.concat(rhs_ref)
61
+
62
+ # WHEN we resolve the expression
63
+ result = concat_expr.resolve(state)
64
+
65
+ # THEN the string should be appended to the list
66
+ assert result == ["hello", "world", "!"]
67
+
68
+
69
+ def test_concat_expression_with_none():
70
+ # GIVEN a list and None item
71
+ state = TestState()
72
+ lhs_ref = ConstantValueReference([1, 2])
73
+ rhs_ref = ConstantValueReference(None)
74
+ concat_expr = lhs_ref.concat(rhs_ref)
75
+
76
+ # WHEN we resolve the expression
77
+ result = concat_expr.resolve(state)
78
+
79
+ # THEN None should be appended to the list
80
+ assert result == [1, 2, None]
81
+
82
+
83
+ def test_concat_expression_with_list_item():
84
+ # GIVEN a list and another list
85
+ state = TestState()
86
+ lhs_ref = ConstantValueReference([1, 2])
87
+ rhs_ref = ConstantValueReference([3, 4])
88
+ concat_expr = lhs_ref.concat(rhs_ref)
89
+
90
+ # WHEN we resolve the expression
91
+ result = concat_expr.resolve(state)
92
+
93
+ # THEN the lists should be concatenated normally
94
+ assert result == [1, 2, 3, 4]
95
+
96
+
97
+ def test_concat_expression_empty_list():
98
+ # GIVEN an empty list and an item
99
+ state = TestState()
100
+ lhs_ref: ConstantValueReference[list] = ConstantValueReference([])
101
+ rhs_ref = ConstantValueReference("first")
102
+ concat_expr = lhs_ref.concat(rhs_ref)
103
+
104
+ # WHEN we resolve the expression
105
+ result = concat_expr.resolve(state)
106
+
107
+ # THEN the item should be appended to the empty list
108
+ assert result == ["first"]
@@ -1,5 +1,6 @@
1
1
  from abc import abstractmethod
2
- from typing import ClassVar, Generator, Generic, Iterator, List, Optional, Union
2
+ from itertools import chain
3
+ from typing import ClassVar, Generator, Generic, Iterator, List, Optional, Union, cast
3
4
 
4
5
  from vellum import AdHocExecutePromptEvent, ExecutePromptEvent, PromptOutput
5
6
  from vellum.client.core import RequestOptions
@@ -60,15 +61,26 @@ class BasePromptNode(BaseNode[StateType], Generic[StateType]):
60
61
  except ApiError as e:
61
62
  self._handle_api_error(e)
62
63
 
63
- # We don't use the INITIATED event anyway, so we can just skip it
64
- # and use the exception handling to catch other api level errors
65
64
  try:
66
- next(prompt_event_stream)
65
+ first_event = next(prompt_event_stream)
67
66
  except ApiError as e:
68
67
  self._handle_api_error(e)
68
+ else:
69
+ if first_event.state == "REJECTED":
70
+ workflow_error = vellum_error_to_workflow_error(first_event.error)
71
+ raise NodeException.of(workflow_error)
72
+ if first_event.state != "INITIATED":
73
+ prompt_event_stream = cast(
74
+ Union[Iterator[AdHocExecutePromptEvent], Iterator[ExecutePromptEvent]],
75
+ chain([first_event], prompt_event_stream),
76
+ )
69
77
 
70
78
  outputs: Optional[List[PromptOutput]] = None
79
+ exception: Optional[NodeException] = None
71
80
  for event in prompt_event_stream:
81
+ if exception:
82
+ continue
83
+
72
84
  if event.state == "INITIATED":
73
85
  continue
74
86
  elif event.state == "STREAMING":
@@ -78,7 +90,10 @@ class BasePromptNode(BaseNode[StateType], Generic[StateType]):
78
90
  yield BaseOutput(name="results", value=event.outputs)
79
91
  elif event.state == "REJECTED":
80
92
  workflow_error = vellum_error_to_workflow_error(event.error)
81
- raise NodeException.of(workflow_error)
93
+ exception = NodeException.of(workflow_error)
94
+
95
+ if exception:
96
+ raise exception
82
97
 
83
98
  return outputs
84
99
 
@@ -1,3 +1,4 @@
1
+ from itertools import chain
1
2
  import json
2
3
  from uuid import uuid4
3
4
  from typing import Callable, ClassVar, Generator, Generic, Iterator, List, Optional, Set, Tuple, Union
@@ -225,13 +226,16 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
225
226
  except ApiError as e:
226
227
  self._handle_api_error(e)
227
228
 
228
- if not self.settings or (self.settings and self.settings.stream_enabled):
229
- # We don't use the INITIATED event anyway, so we can just skip it
230
- # and use the exception handling to catch other api level errors
231
- try:
232
- next(prompt_event_stream)
233
- except ApiError as e:
234
- self._handle_api_error(e)
229
+ try:
230
+ first_event = next(prompt_event_stream)
231
+ except ApiError as e:
232
+ self._handle_api_error(e)
233
+ else:
234
+ if first_event.state == "REJECTED":
235
+ workflow_error = vellum_error_to_workflow_error(first_event.error)
236
+ raise NodeException.of(workflow_error)
237
+ if first_event.state != "INITIATED":
238
+ prompt_event_stream = chain([first_event], prompt_event_stream)
235
239
 
236
240
  outputs: Optional[List[PromptOutput]] = None
237
241
  for event in prompt_event_stream:
@@ -25,6 +25,7 @@ from vellum import (
25
25
  PromptRequestStringInput,
26
26
  PromptRequestVideoInput,
27
27
  PromptSettings,
28
+ RejectedExecutePromptEvent,
28
29
  RichTextPromptBlock,
29
30
  StringVellumValue,
30
31
  VariablePromptBlock,
@@ -32,6 +33,7 @@ from vellum import (
32
33
  VellumAudioRequest,
33
34
  VellumDocument,
34
35
  VellumDocumentRequest,
36
+ VellumError,
35
37
  VellumImage,
36
38
  VellumImageRequest,
37
39
  VellumVideo,
@@ -896,3 +898,43 @@ def test_inline_prompt_node__json_output_with_markdown_code_blocks(vellum_adhoc_
896
898
  json_output = outputs[2]
897
899
  assert json_output.name == "json"
898
900
  assert json_output.value == expected_json
901
+
902
+
903
+ def test_inline_prompt_node__provider_error_from_api(vellum_adhoc_prompt_client):
904
+ """
905
+ Tests that InlinePromptNode raises NodeException with PROVIDER_ERROR code when first event is REJECTED.
906
+ """
907
+
908
+ # GIVEN an InlinePromptNode with basic configuration
909
+ class TestNode(InlinePromptNode):
910
+ ml_model = "test-model"
911
+ blocks = []
912
+ prompt_inputs = {}
913
+
914
+ # AND the API returns a REJECTED event as the first event with a provider error
915
+ provider_error = VellumError(
916
+ code="PROVIDER_ERROR",
917
+ message="Provider rate limit exceeded",
918
+ )
919
+
920
+ def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
921
+ execution_id = str(uuid4())
922
+ events: List[ExecutePromptEvent] = [
923
+ RejectedExecutePromptEvent(
924
+ execution_id=execution_id,
925
+ error=provider_error,
926
+ ),
927
+ ]
928
+ yield from events
929
+
930
+ vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.side_effect = generate_prompt_events
931
+
932
+ # WHEN the node is run
933
+ node = TestNode()
934
+
935
+ # THEN it should raise a NodeException with PROVIDER_ERROR error code
936
+ with pytest.raises(NodeException) as excinfo:
937
+ list(node.run())
938
+
939
+ # AND the exception should have the correct error code
940
+ assert excinfo.value.code == WorkflowErrorCode.PROVIDER_ERROR
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, Generic, Tuple, Type, TypeVar, get_args
1
+ from typing import Any, Dict, Generic, Tuple, Type, TypeVar, get_args, get_origin
2
2
 
3
3
  from vellum.workflows.constants import undefined
4
4
  from vellum.workflows.nodes.bases import BaseNode
@@ -74,6 +74,12 @@ class _FinalOutputNodeMeta(BaseNodeMeta):
74
74
  if descriptor_type == declared_output_type:
75
75
  type_mismatch = False
76
76
  break
77
+ if (
78
+ get_origin(descriptor_type) == declared_output_type
79
+ or get_origin(declared_output_type) == descriptor_type
80
+ ):
81
+ type_mismatch = False
82
+ break
77
83
  try:
78
84
  if issubclass(descriptor_type, declared_output_type) or issubclass(
79
85
  declared_output_type, descriptor_type
@@ -1,8 +1,10 @@
1
1
  import pytest
2
+ from typing import Any, Dict
2
3
 
3
4
  from vellum.workflows.exceptions import NodeException
4
5
  from vellum.workflows.nodes.displayable.final_output_node import FinalOutputNode
5
6
  from vellum.workflows.nodes.displayable.inline_prompt_node import InlinePromptNode
7
+ from vellum.workflows.references.output import OutputReference
6
8
  from vellum.workflows.state.base import BaseState
7
9
 
8
10
 
@@ -57,3 +59,29 @@ def test_final_output_node__matching_output_type_should_pass_validation():
57
59
  CorrectOutput.__validate__()
58
60
  except ValueError:
59
61
  pytest.fail("Validation should not raise an exception for correct type matching")
62
+
63
+
64
+ def test_final_output_node__dict_and_Dict_should_be_compatible():
65
+ """
66
+ Tests that FinalOutputNode validation recognizes dict and Dict[str, Any] as compatible types.
67
+ """
68
+
69
+ # GIVEN a FinalOutputNode declared with dict output type
70
+ # AND the value descriptor has Dict[str, Any] type
71
+ class DictOutputNode(FinalOutputNode[BaseState, dict]):
72
+ """Output with dict type."""
73
+
74
+ class Outputs(FinalOutputNode.Outputs):
75
+ value = OutputReference(
76
+ name="value",
77
+ types=(Dict[str, Any],),
78
+ instance=None,
79
+ outputs_class=FinalOutputNode.Outputs,
80
+ )
81
+
82
+ # WHEN attempting to validate the node class
83
+ # THEN validation should pass without raising an exception
84
+ try:
85
+ DictOutputNode.__validate__()
86
+ except ValueError as e:
87
+ pytest.fail(f"Validation should not raise an exception for dict/Dict compatibility: {e}")
@@ -1,6 +1,6 @@
1
1
  import json
2
- from uuid import UUID
3
- from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Iterator, List, Optional, Set, Union, cast
2
+ from uuid import UUID, uuid4
3
+ from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Iterator, List, Optional, Set, Type, Union, cast
4
4
 
5
5
  from vellum import (
6
6
  ChatMessage,
@@ -19,12 +19,13 @@ from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT, undefined
19
19
  from vellum.workflows.context import execution_context, get_execution_context, get_parent_context
20
20
  from vellum.workflows.errors import WorkflowErrorCode
21
21
  from vellum.workflows.errors.types import workflow_event_error_to_workflow_error
22
- from vellum.workflows.events.types import default_serializer
22
+ from vellum.workflows.events.types import WorkflowDeploymentParentContext, default_serializer
23
23
  from vellum.workflows.events.workflow import is_workflow_event
24
24
  from vellum.workflows.exceptions import NodeException, WorkflowInitializationException
25
25
  from vellum.workflows.inputs.base import BaseInputs
26
26
  from vellum.workflows.nodes.bases.base import BaseNode
27
27
  from vellum.workflows.outputs.base import BaseOutput
28
+ from vellum.workflows.state.context import WorkflowContext, WorkflowDeploymentMetadata
28
29
  from vellum.workflows.types.core import EntityInputsInterface, MergeBehavior
29
30
  from vellum.workflows.types.generics import StateType
30
31
  from vellum.workflows.workflows.event_filters import all_workflow_event_filter
@@ -155,69 +156,95 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
155
156
  filtered_inputs = {k: v for k, v in self.subworkflow_inputs.items() if v is not None}
156
157
  return inputs_class(**filtered_inputs)
157
158
 
158
- def _run_resolved_workflow(self, resolved_workflow: "BaseWorkflow") -> Iterator[BaseOutput]:
159
+ def _run_resolved_workflow(
160
+ self,
161
+ workflow_class: Type["BaseWorkflow"],
162
+ deployment_metadata: Optional[WorkflowDeploymentMetadata],
163
+ ) -> Iterator[BaseOutput]:
159
164
  """Execute resolved workflow directly (similar to InlineSubworkflowNode)."""
160
- with execution_context(parent_context=get_parent_context()):
165
+ # Construct the parent context hierarchy for the subworkflow
166
+ parent_context = get_parent_context()
167
+
168
+ # If we have deployment metadata, wrap the parent context with WorkflowDeploymentParentContext
169
+ if deployment_metadata:
170
+ parent_context = WorkflowDeploymentParentContext(
171
+ span_id=uuid4(),
172
+ deployment_id=deployment_metadata.deployment_id,
173
+ deployment_name=deployment_metadata.deployment_name,
174
+ deployment_history_item_id=deployment_metadata.deployment_history_item_id,
175
+ release_tag_id=deployment_metadata.release_tag_id,
176
+ release_tag_name=deployment_metadata.release_tag_name,
177
+ workflow_version_id=deployment_metadata.workflow_version_id,
178
+ external_id=self.external_id if self.external_id is not OMIT else None,
179
+ metadata=self.metadata if self.metadata is not OMIT else None,
180
+ parent=parent_context,
181
+ )
182
+
183
+ with execution_context(parent_context=parent_context):
184
+ # Instantiate the workflow inside the execution context so it captures the correct parent context
185
+ resolved_workflow = workflow_class(
186
+ context=WorkflowContext.create_from(self._context), parent_state=self.state
187
+ )
161
188
  subworkflow_stream = resolved_workflow.stream(
162
189
  inputs=self._compile_subworkflow_inputs_for_direct_invocation(resolved_workflow),
163
190
  event_filter=all_workflow_event_filter,
164
191
  node_output_mocks=self._context._get_all_node_output_mocks(),
165
192
  )
166
193
 
167
- try:
168
- first_event = next(subworkflow_stream)
169
- self._context._emit_subworkflow_event(first_event)
170
- except WorkflowInitializationException as e:
171
- hashed_module = e.definition.__module__
172
- raise NodeException(
173
- message=e.message,
174
- code=e.code,
175
- raw_data={"hashed_module": hashed_module},
176
- ) from e
177
-
178
- outputs = None
179
- exception = None
180
- fulfilled_output_names: Set[str] = set()
194
+ try:
195
+ first_event = next(subworkflow_stream)
196
+ self._context._emit_subworkflow_event(first_event)
197
+ except WorkflowInitializationException as e:
198
+ hashed_module = e.definition.__module__
199
+ raise NodeException(
200
+ message=e.message,
201
+ code=e.code,
202
+ raw_data={"hashed_module": hashed_module},
203
+ ) from e
204
+
205
+ outputs = None
206
+ exception = None
207
+ fulfilled_output_names: Set[str] = set()
208
+
209
+ for event in subworkflow_stream:
210
+ self._context._emit_subworkflow_event(event)
211
+ if exception:
212
+ continue
213
+
214
+ if not is_workflow_event(event):
215
+ continue
216
+ if event.workflow_definition != resolved_workflow.__class__:
217
+ continue
218
+
219
+ if event.name == "workflow.execution.streaming":
220
+ if event.output.is_fulfilled:
221
+ fulfilled_output_names.add(event.output.name)
222
+ yield event.output
223
+ elif event.name == "workflow.execution.fulfilled":
224
+ outputs = event.outputs
225
+ elif event.name == "workflow.execution.rejected":
226
+ exception = NodeException.of(event.error)
227
+ elif event.name == "workflow.execution.paused":
228
+ exception = NodeException(
229
+ code=WorkflowErrorCode.INVALID_OUTPUTS,
230
+ message="Subworkflow unexpectedly paused",
231
+ )
181
232
 
182
- for event in subworkflow_stream:
183
- self._context._emit_subworkflow_event(event)
184
233
  if exception:
185
- continue
234
+ raise exception
186
235
 
187
- if not is_workflow_event(event):
188
- continue
189
- if event.workflow_definition != resolved_workflow.__class__:
190
- continue
191
-
192
- if event.name == "workflow.execution.streaming":
193
- if event.output.is_fulfilled:
194
- fulfilled_output_names.add(event.output.name)
195
- yield event.output
196
- elif event.name == "workflow.execution.fulfilled":
197
- outputs = event.outputs
198
- elif event.name == "workflow.execution.rejected":
199
- exception = NodeException.of(event.error)
200
- elif event.name == "workflow.execution.paused":
201
- exception = NodeException(
236
+ if outputs is None:
237
+ raise NodeException(
238
+ message="Expected to receive outputs from Workflow Deployment",
202
239
  code=WorkflowErrorCode.INVALID_OUTPUTS,
203
- message="Subworkflow unexpectedly paused",
204
240
  )
205
241
 
206
- if exception:
207
- raise exception
208
-
209
- if outputs is None:
210
- raise NodeException(
211
- message="Expected to receive outputs from Workflow Deployment",
212
- code=WorkflowErrorCode.INVALID_OUTPUTS,
213
- )
214
-
215
- for output_descriptor, output_value in outputs:
216
- if output_descriptor.name not in fulfilled_output_names:
217
- yield BaseOutput(
218
- name=output_descriptor.name,
219
- value=output_value,
220
- )
242
+ for output_descriptor, output_value in outputs:
243
+ if output_descriptor.name not in fulfilled_output_names:
244
+ yield BaseOutput(
245
+ name=output_descriptor.name,
246
+ value=output_value,
247
+ )
221
248
 
222
249
  def run(self) -> Iterator[BaseOutput]:
223
250
  execution_context = get_execution_context()
@@ -248,11 +275,12 @@ class SubworkflowDeploymentNode(BaseNode[StateType], Generic[StateType]):
248
275
  message="Expected deployment name to be provided for subworkflow execution.",
249
276
  )
250
277
 
251
- resolved_workflow = self._context.resolve_workflow_deployment(
278
+ resolved_result = self._context.resolve_workflow_deployment(
252
279
  deployment_name=deployment_name, release_tag=self.release_tag, state=self.state
253
280
  )
254
- if resolved_workflow:
255
- yield from self._run_resolved_workflow(resolved_workflow)
281
+ if resolved_result:
282
+ workflow_class, deployment_metadata = resolved_result
283
+ yield from self._run_resolved_workflow(workflow_class, deployment_metadata)
256
284
  return
257
285
 
258
286
  try:
@@ -1,3 +1 @@
1
- from .openai_chat_completion_node import OpenAIChatCompletionNode
2
-
3
- __all__ = ["OpenAIChatCompletionNode"]
1
+ __all__: list[str] = []