vellum-ai 0.14.37__py3-none-any.whl → 0.14.38__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
Files changed (37)
  1. vellum/__init__.py +8 -0
  2. vellum/client/core/client_wrapper.py +1 -1
  3. vellum/client/reference.md +6272 -0
  4. vellum/client/types/__init__.py +8 -0
  5. vellum/client/types/ad_hoc_fulfilled_prompt_execution_meta.py +2 -0
  6. vellum/client/types/fulfilled_prompt_execution_meta.py +2 -0
  7. vellum/client/types/test_suite_run_exec_config_request.py +4 -0
  8. vellum/client/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +27 -0
  9. vellum/client/types/test_suite_run_prompt_sandbox_exec_config_request.py +29 -0
  10. vellum/client/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +22 -0
  11. vellum/client/types/test_suite_run_workflow_sandbox_exec_config_request.py +29 -0
  12. vellum/plugins/pydantic.py +1 -1
  13. vellum/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +3 -0
  14. vellum/types/test_suite_run_prompt_sandbox_exec_config_request.py +3 -0
  15. vellum/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +3 -0
  16. vellum/types/test_suite_run_workflow_sandbox_exec_config_request.py +3 -0
  17. vellum/workflows/events/node.py +2 -1
  18. vellum/workflows/events/types.py +3 -40
  19. vellum/workflows/events/workflow.py +2 -1
  20. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +94 -3
  21. vellum/workflows/nodes/displayable/conftest.py +2 -6
  22. vellum/workflows/nodes/displayable/guardrail_node/node.py +1 -1
  23. vellum/workflows/nodes/displayable/guardrail_node/tests/__init__.py +0 -0
  24. vellum/workflows/nodes/displayable/guardrail_node/tests/test_node.py +50 -0
  25. vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +297 -0
  26. vellum/workflows/runner/runner.py +44 -43
  27. vellum/workflows/state/base.py +149 -45
  28. vellum/workflows/types/definition.py +71 -0
  29. vellum/workflows/types/generics.py +34 -1
  30. vellum/workflows/workflows/base.py +20 -3
  31. vellum/workflows/workflows/tests/test_base_workflow.py +232 -1
  32. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/METADATA +1 -1
  33. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/RECORD +37 -25
  34. vellum_ee/workflows/display/vellum.py +0 -5
  35. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/LICENSE +0 -0
  36. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/WHEEL +0 -0
  37. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/entry_points.txt +0 -0
@@ -453,6 +453,8 @@ from .test_suite_run_metric_json_output import TestSuiteRunMetricJsonOutput
 from .test_suite_run_metric_number_output import TestSuiteRunMetricNumberOutput
 from .test_suite_run_metric_output import TestSuiteRunMetricOutput
 from .test_suite_run_metric_string_output import TestSuiteRunMetricStringOutput
+from .test_suite_run_prompt_sandbox_exec_config_data_request import TestSuiteRunPromptSandboxExecConfigDataRequest
+from .test_suite_run_prompt_sandbox_exec_config_request import TestSuiteRunPromptSandboxExecConfigRequest
 from .test_suite_run_prompt_sandbox_history_item_exec_config import TestSuiteRunPromptSandboxHistoryItemExecConfig
 from .test_suite_run_prompt_sandbox_history_item_exec_config_data import (
     TestSuiteRunPromptSandboxHistoryItemExecConfigData,
@@ -472,6 +474,8 @@ from .test_suite_run_workflow_release_tag_exec_config_data_request import (
     TestSuiteRunWorkflowReleaseTagExecConfigDataRequest,
 )
 from .test_suite_run_workflow_release_tag_exec_config_request import TestSuiteRunWorkflowReleaseTagExecConfigRequest
+from .test_suite_run_workflow_sandbox_exec_config_data_request import TestSuiteRunWorkflowSandboxExecConfigDataRequest
+from .test_suite_run_workflow_sandbox_exec_config_request import TestSuiteRunWorkflowSandboxExecConfigRequest
 from .test_suite_run_workflow_sandbox_history_item_exec_config import TestSuiteRunWorkflowSandboxHistoryItemExecConfig
 from .test_suite_run_workflow_sandbox_history_item_exec_config_data import (
     TestSuiteRunWorkflowSandboxHistoryItemExecConfigData,
@@ -1048,6 +1052,8 @@ __all__ = [
     "TestSuiteRunMetricNumberOutput",
     "TestSuiteRunMetricOutput",
     "TestSuiteRunMetricStringOutput",
+    "TestSuiteRunPromptSandboxExecConfigDataRequest",
+    "TestSuiteRunPromptSandboxExecConfigRequest",
     "TestSuiteRunPromptSandboxHistoryItemExecConfig",
     "TestSuiteRunPromptSandboxHistoryItemExecConfigData",
     "TestSuiteRunPromptSandboxHistoryItemExecConfigDataRequest",
@@ -1059,6 +1065,8 @@ __all__ = [
     "TestSuiteRunWorkflowReleaseTagExecConfigData",
     "TestSuiteRunWorkflowReleaseTagExecConfigDataRequest",
     "TestSuiteRunWorkflowReleaseTagExecConfigRequest",
+    "TestSuiteRunWorkflowSandboxExecConfigDataRequest",
+    "TestSuiteRunWorkflowSandboxExecConfigRequest",
     "TestSuiteRunWorkflowSandboxHistoryItemExecConfig",
     "TestSuiteRunWorkflowSandboxHistoryItemExecConfigData",
     "TestSuiteRunWorkflowSandboxHistoryItemExecConfigDataRequest",
@@ -4,6 +4,7 @@ from ..core.pydantic_utilities import UniversalBaseModel
 import typing
 from .finish_reason_enum import FinishReasonEnum
 from .ml_model_usage import MlModelUsage
+from .price import Price
 from ..core.pydantic_utilities import IS_PYDANTIC_V2
 import pydantic
 
@@ -16,6 +17,7 @@ class AdHocFulfilledPromptExecutionMeta(UniversalBaseModel):
     latency: typing.Optional[int] = None
     finish_reason: typing.Optional[FinishReasonEnum] = None
     usage: typing.Optional[MlModelUsage] = None
+    cost: typing.Optional[Price] = None
 
     if IS_PYDANTIC_V2:
         model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
@@ -4,6 +4,7 @@ from ..core.pydantic_utilities import UniversalBaseModel
 import typing
 from .finish_reason_enum import FinishReasonEnum
 from .ml_model_usage import MlModelUsage
+from .price import Price
 from ..core.pydantic_utilities import IS_PYDANTIC_V2
 import pydantic
 
@@ -16,6 +17,7 @@ class FulfilledPromptExecutionMeta(UniversalBaseModel):
     latency: typing.Optional[int] = None
     finish_reason: typing.Optional[FinishReasonEnum] = None
     usage: typing.Optional[MlModelUsage] = None
+    cost: typing.Optional[Price] = None
 
     if IS_PYDANTIC_V2:
         model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
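Note on the two hunks above: both execution-meta models gain an optional cost field typed as the SDK's Price model. A minimal sketch of reading it from a streamed prompt execution, assuming the API populates meta.cost on the FULFILLED event and that Price exposes value and unit fields; the deployment name and empty inputs are placeholders:

from vellum import Vellum

client = Vellum(api_key="...")

# Stream a deployed prompt and read the new cost field off the FULFILLED event.
for event in client.execute_prompt_stream(
    prompt_deployment_name="my-deployment",
    inputs=[],
):
    if event.state == "FULFILLED" and event.meta and event.meta.cost:
        print(f"cost: {event.meta.cost.value} {event.meta.cost.unit}")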
@@ -2,10 +2,12 @@
 
 import typing
 from .test_suite_run_deployment_release_tag_exec_config_request import TestSuiteRunDeploymentReleaseTagExecConfigRequest
+from .test_suite_run_prompt_sandbox_exec_config_request import TestSuiteRunPromptSandboxExecConfigRequest
 from .test_suite_run_prompt_sandbox_history_item_exec_config_request import (
     TestSuiteRunPromptSandboxHistoryItemExecConfigRequest,
 )
 from .test_suite_run_workflow_release_tag_exec_config_request import TestSuiteRunWorkflowReleaseTagExecConfigRequest
+from .test_suite_run_workflow_sandbox_exec_config_request import TestSuiteRunWorkflowSandboxExecConfigRequest
 from .test_suite_run_workflow_sandbox_history_item_exec_config_request import (
     TestSuiteRunWorkflowSandboxHistoryItemExecConfigRequest,
 )
@@ -13,8 +15,10 @@ from .test_suite_run_external_exec_config_request import TestSuiteRunExternalExecConfigRequest
 
 TestSuiteRunExecConfigRequest = typing.Union[
     TestSuiteRunDeploymentReleaseTagExecConfigRequest,
+    TestSuiteRunPromptSandboxExecConfigRequest,
     TestSuiteRunPromptSandboxHistoryItemExecConfigRequest,
     TestSuiteRunWorkflowReleaseTagExecConfigRequest,
+    TestSuiteRunWorkflowSandboxExecConfigRequest,
     TestSuiteRunWorkflowSandboxHistoryItemExecConfigRequest,
     TestSuiteRunExternalExecConfigRequest,
 ]
@@ -0,0 +1,27 @@
+# This file was auto-generated by Fern from our API Definition.
+
+from ..core.pydantic_utilities import UniversalBaseModel
+import pydantic
+from ..core.pydantic_utilities import IS_PYDANTIC_V2
+import typing
+
+
+class TestSuiteRunPromptSandboxExecConfigDataRequest(UniversalBaseModel):
+    prompt_sandbox_id: str = pydantic.Field()
+    """
+    The ID of the Prompt Sandbox to run the Test Suite against.
+    """
+
+    prompt_variant_id: str = pydantic.Field()
+    """
+    The ID of the Prompt Variant within the Prompt Sandbox that you'd like to run the Test Suite against.
+    """
+
+    if IS_PYDANTIC_V2:
+        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
+    else:
+
+        class Config:
+            frozen = True
+            smart_union = True
+            extra = pydantic.Extra.allow
@@ -0,0 +1,29 @@
+# This file was auto-generated by Fern from our API Definition.
+
+from ..core.pydantic_utilities import UniversalBaseModel
+import typing
+from .test_suite_run_prompt_sandbox_exec_config_data_request import TestSuiteRunPromptSandboxExecConfigDataRequest
+import pydantic
+from ..core.pydantic_utilities import IS_PYDANTIC_V2
+
+
+class TestSuiteRunPromptSandboxExecConfigRequest(UniversalBaseModel):
+    """
+    Execution configuration for running a Test Suite against a Prompt Sandbox
+    """
+
+    type: typing.Literal["PROMPT_SANDBOX"] = "PROMPT_SANDBOX"
+    data: TestSuiteRunPromptSandboxExecConfigDataRequest
+    test_case_ids: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
+    """
+    Optionally specify a subset of test case ids to run. If not provided, all test cases within the test suite will be run by default.
+    """
+
+    if IS_PYDANTIC_V2:
+        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
+    else:
+
+        class Config:
+            frozen = True
+            smart_union = True
+            extra = pydantic.Extra.allow
@@ -0,0 +1,22 @@
+# This file was auto-generated by Fern from our API Definition.
+
+from ..core.pydantic_utilities import UniversalBaseModel
+import pydantic
+from ..core.pydantic_utilities import IS_PYDANTIC_V2
+import typing
+
+
+class TestSuiteRunWorkflowSandboxExecConfigDataRequest(UniversalBaseModel):
+    workflow_sandbox_id: str = pydantic.Field()
+    """
+    The ID of the Workflow Sandbox to run the Test Suite against.
+    """
+
+    if IS_PYDANTIC_V2:
+        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
+    else:
+
+        class Config:
+            frozen = True
+            smart_union = True
+            extra = pydantic.Extra.allow
@@ -0,0 +1,29 @@
+# This file was auto-generated by Fern from our API Definition.
+
+from ..core.pydantic_utilities import UniversalBaseModel
+import typing
+from .test_suite_run_workflow_sandbox_exec_config_data_request import TestSuiteRunWorkflowSandboxExecConfigDataRequest
+import pydantic
+from ..core.pydantic_utilities import IS_PYDANTIC_V2
+
+
+class TestSuiteRunWorkflowSandboxExecConfigRequest(UniversalBaseModel):
+    """
+    Execution configuration for running a Test Suite against a Workflow Sandbox
+    """
+
+    type: typing.Literal["WORKFLOW_SANDBOX"] = "WORKFLOW_SANDBOX"
+    data: TestSuiteRunWorkflowSandboxExecConfigDataRequest
+    test_case_ids: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
+    """
+    Optionally specify a subset of test case ids to run. If not provided, all test cases within the test suite will be run by default.
+    """
+
+    if IS_PYDANTIC_V2:
+        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
+    else:
+
+        class Config:
+            frozen = True
+            smart_union = True
+            extra = pydantic.Extra.allow
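The four new files above back the two variants added to the TestSuiteRunExecConfigRequest union. A minimal sketch of constructing the new PROMPT_SANDBOX variant and starting a run with it, assuming the existing test_suite_runs.create endpoint accepts the widened union; the IDs are placeholders, and the type discriminator defaults to "PROMPT_SANDBOX" so it need not be passed:

from vellum import Vellum
from vellum.client.types.test_suite_run_prompt_sandbox_exec_config_data_request import (
    TestSuiteRunPromptSandboxExecConfigDataRequest,
)
from vellum.client.types.test_suite_run_prompt_sandbox_exec_config_request import (
    TestSuiteRunPromptSandboxExecConfigRequest,
)

client = Vellum(api_key="...")

# Point the run at a specific Prompt Variant within a Prompt Sandbox.
exec_config = TestSuiteRunPromptSandboxExecConfigRequest(
    data=TestSuiteRunPromptSandboxExecConfigDataRequest(
        prompt_sandbox_id="<prompt-sandbox-id>",
        prompt_variant_id="<prompt-variant-id>",
    ),
)
run = client.test_suite_runs.create(
    test_suite_id="<test-suite-id>",
    exec_config=exec_config,
)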
@@ -53,7 +53,7 @@ class OnValidatePython(ValidatePythonHandlerProtocol):
             return
 
         if self_instance:
-            model_fields: Dict[str, FieldInfo] = self_instance.model_fields
+            model_fields: Dict[str, FieldInfo] = self_instance.__class__.model_fields
         else:
             model_fields = {}
 
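The one-line vellum/plugins/pydantic.py change reads model_fields from the class rather than the instance. Recent Pydantic v2 releases deprecate instance-level access to model_fields, while class-level access remains supported; a small illustration (the warning behavior is version-dependent, so treat it as an assumption about newer Pydantic releases):

import pydantic


class Point(pydantic.BaseModel):
    x: int
    y: int


p = Point(x=1, y=2)

fields = type(p).model_fields  # class-level access: always safe on Pydantic v2
assert set(fields) == {"x", "y"}

# p.model_fields still resolves today, but newer Pydantic v2 releases emit a
# DeprecationWarning for instance-level access, which is what this fix avoids.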
@@ -0,0 +1,3 @@
+# WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
+
+from vellum.client.types.test_suite_run_prompt_sandbox_exec_config_data_request import *
@@ -0,0 +1,3 @@
+# WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
+
+from vellum.client.types.test_suite_run_prompt_sandbox_exec_config_request import *
@@ -0,0 +1,3 @@
+# WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
+
+from vellum.client.types.test_suite_run_workflow_sandbox_exec_config_data_request import *
@@ -0,0 +1,3 @@
+# WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
+
+from vellum.client.types.test_suite_run_workflow_sandbox_exec_config_request import *
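These four vellum/types/* stubs follow the package's long-standing deprecation pattern: each wildcard-re-exports the generated class from vellum.client.types. A quick illustration of the equivalence, aliased to avoid a name clash; the identity assertion assumes the wildcard re-export resolves to the same class object, which the stub's import * implies:

from vellum.types.test_suite_run_prompt_sandbox_exec_config_request import (
    TestSuiteRunPromptSandboxExecConfigRequest as DeprecatedPath,  # path slated for removal
)
from vellum.client.types import TestSuiteRunPromptSandboxExecConfigRequest as PreferredPath

assert DeprecatedPath is PreferredPath  # same generated class, two import paths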
@@ -8,9 +8,10 @@ from vellum.workflows.expressions.accessor import AccessorExpression
 from vellum.workflows.outputs.base import BaseOutput
 from vellum.workflows.ports.port import Port
 from vellum.workflows.references.node import NodeReference
+from vellum.workflows.types.definition import serialize_type_encoder_with_id
 from vellum.workflows.types.generics import OutputsType
 
-from .types import BaseEvent, default_serializer, serialize_type_encoder_with_id
+from .types import BaseEvent, default_serializer
 
 if TYPE_CHECKING:
     from vellum.workflows.nodes.bases import BaseNode
@@ -1,13 +1,14 @@
 from datetime import datetime
 import json
 from uuid import UUID, uuid4
-from typing import Annotated, Any, Dict, List, Literal, Optional, Union, get_args
+from typing import Annotated, Any, Literal, Optional, Union, get_args
 
-from pydantic import BeforeValidator, Field, GetCoreSchemaHandler, Tag, ValidationInfo
+from pydantic import Field, GetCoreSchemaHandler, Tag, ValidationInfo
 from pydantic_core import CoreSchema, core_schema
 
 from vellum.core.pydantic_utilities import UniversalBaseModel
 from vellum.workflows.state.encoder import DefaultStateEncoder
+from vellum.workflows.types.definition import VellumCodeResourceDefinition
 from vellum.workflows.types.utils import datetime_now
 
 
@@ -19,28 +20,6 @@ def default_datetime_factory() -> datetime:
     return datetime_now()
 
 
-excluded_modules = {"typing", "builtins"}
-
-
-def serialize_type_encoder(obj: type) -> Dict[str, Any]:
-    return {
-        "name": obj.__name__,
-        "module": obj.__module__.split("."),
-    }
-
-
-def serialize_type_encoder_with_id(obj: Union[type, "CodeResourceDefinition"]) -> Dict[str, Any]:
-    if hasattr(obj, "__id__") and isinstance(obj, type):
-        return {
-            "id": getattr(obj, "__id__"),
-            **serialize_type_encoder(obj),
-        }
-    elif isinstance(obj, CodeResourceDefinition):
-        return obj.model_dump(mode="json")
-
-    raise AttributeError(f"The object of type '{type(obj).__name__}' must have an '__id__' attribute.")
-
-
 def default_serializer(obj: Any) -> Any:
     return json.loads(
         json.dumps(
@@ -50,22 +29,6 @@ def default_serializer(obj: Any) -> Any:
     )
 
 
-class CodeResourceDefinition(UniversalBaseModel):
-    id: UUID
-    name: str
-    module: List[str]
-
-    @staticmethod
-    def encode(obj: type) -> "CodeResourceDefinition":
-        return CodeResourceDefinition(**serialize_type_encoder_with_id(obj))
-
-
-VellumCodeResourceDefinition = Annotated[
-    CodeResourceDefinition,
-    BeforeValidator(lambda d: (d if type(d) is dict else serialize_type_encoder_with_id(d))),
-]
-
-
 class BaseParentContext(UniversalBaseModel):
     span_id: UUID
     parent: Optional["ParentContext"] = None
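CodeResourceDefinition, VellumCodeResourceDefinition, and serialize_type_encoder_with_id are removed from vellum/workflows/events/types.py and now live in the new vellum/workflows/types/definition.py (file 28 in the list, +71 lines). A minimal sketch of the updated import path, assuming the helpers keep their pre-move behavior, as the relocated usages elsewhere in this diff suggest:

from vellum.workflows.types.definition import (
    CodeResourceDefinition,
    serialize_type_encoder_with_id,
)


class MyNode:
    __id__ = "11111111-2222-3333-4444-555555555555"  # placeholder __id__


# Same {"id": ..., "name": ..., "module": [...]} shape as before the move.
payload = serialize_type_encoder_with_id(MyNode)
resource = CodeResourceDefinition(**payload)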
@@ -8,6 +8,7 @@ from vellum.core.pydantic_utilities import UniversalBaseModel
 from vellum.workflows.errors import WorkflowError
 from vellum.workflows.outputs.base import BaseOutput
 from vellum.workflows.references import ExternalInputReference
+from vellum.workflows.types.definition import serialize_type_encoder_with_id
 from vellum.workflows.types.generics import InputsType, OutputsType, StateType
 
 from .node import (
@@ -18,7 +19,7 @@ from .node import (
     NodeExecutionResumedEvent,
     NodeExecutionStreamingEvent,
 )
-from .types import BaseEvent, default_serializer, serialize_type_encoder_with_id
+from .types import BaseEvent, default_serializer
 
 if TYPE_CHECKING:
     from vellum.workflows.workflows.base import BaseWorkflow
@@ -1,6 +1,6 @@
 import json
 from uuid import UUID
-from typing import Any, ClassVar, Dict, Generic, Iterator, List, Optional, Sequence, Union
+from typing import Any, ClassVar, Dict, Generator, Generic, Iterator, List, Optional, Sequence, Set, Union
 
 from vellum import (
     ChatHistoryInputRequest,
@@ -9,17 +9,20 @@ from vellum import (
     JsonInputRequest,
     PromptDeploymentExpandMetaRequest,
     PromptDeploymentInputRequest,
+    PromptOutput,
     RawPromptExecutionOverridesRequest,
     StringInputRequest,
 )
-from vellum.client import RequestOptions
+from vellum.client import ApiError, RequestOptions
 from vellum.client.types.chat_message_request import ChatMessageRequest
 from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT
 from vellum.workflows.context import get_execution_context
 from vellum.workflows.errors import WorkflowErrorCode
+from vellum.workflows.errors.types import vellum_error_to_workflow_error
 from vellum.workflows.events.types import default_serializer
 from vellum.workflows.exceptions import NodeException
 from vellum.workflows.nodes.displayable.bases.base_prompt_node import BasePromptNode
+from vellum.workflows.outputs import BaseOutput
 from vellum.workflows.types import MergeBehavior
 from vellum.workflows.types.generics import StateType
 
@@ -56,13 +59,21 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
     class Trigger(BasePromptNode.Trigger):
         merge_behavior = MergeBehavior.AWAIT_ANY
 
-    def _get_prompt_event_stream(self) -> Iterator[ExecutePromptEvent]:
+    def _get_prompt_event_stream(self, ml_model_fallback: Optional[str] = None) -> Iterator[ExecutePromptEvent]:
         execution_context = get_execution_context()
         request_options = self.request_options or RequestOptions()
         request_options["additional_body_parameters"] = {
             "execution_context": execution_context.model_dump(mode="json"),
             **request_options.get("additional_body_parameters", {}),
         }
+        if ml_model_fallback:
+            request_options["additional_body_parameters"] = {
+                "overrides": {
+                    "ml_model_fallback": ml_model_fallback,
+                },
+                **request_options.get("additional_body_parameters", {}),
+            }
+
         return self._context.vellum_client.execute_prompt_stream(
             inputs=self._compile_prompt_inputs(),
             prompt_deployment_id=str(self.deployment) if isinstance(self.deployment, UUID) else None,
@@ -76,6 +87,86 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
             request_options=request_options,
         )
 
+    def _process_prompt_event_stream(
+        self,
+        prompt_event_stream: Optional[Iterator[ExecutePromptEvent]] = None,
+        tried_fallbacks: Optional[set[str]] = None,
+    ) -> Generator[BaseOutput, None, Optional[List[PromptOutput]]]:
+        """Override the base prompt node _process_prompt_event_stream()"""
+        self._validate()
+
+        if tried_fallbacks is None:
+            tried_fallbacks = set()
+
+        if prompt_event_stream is None:
+            try:
+                prompt_event_stream = self._get_prompt_event_stream()
+                next(prompt_event_stream)
+            except ApiError as e:
+                if (
+                    e.status_code
+                    and e.status_code < 500
+                    and self.ml_model_fallbacks is not OMIT
+                    and self.ml_model_fallbacks is not None
+                ):
+                    prompt_event_stream = self._retry_prompt_stream_with_fallbacks(tried_fallbacks)
+                else:
+                    self._handle_api_error(e)
+
+        outputs: Optional[List[PromptOutput]] = None
+        if prompt_event_stream is not None:
+            for event in prompt_event_stream:
+                if event.state == "INITIATED":
+                    continue
+                elif event.state == "STREAMING":
+                    yield BaseOutput(name="results", delta=event.output.value)
+                elif event.state == "FULFILLED":
+                    outputs = event.outputs
+                    yield BaseOutput(name="results", value=event.outputs)
+                elif event.state == "REJECTED":
+                    if (
+                        event.error
+                        and event.error.code == WorkflowErrorCode.PROVIDER_ERROR.value
+                        and self.ml_model_fallbacks is not OMIT
+                        and self.ml_model_fallbacks is not None
+                    ):
+                        try:
+                            fallback_stream = self._retry_prompt_stream_with_fallbacks(tried_fallbacks)
+                            fallback_outputs = yield from self._process_prompt_event_stream(
+                                fallback_stream, tried_fallbacks
+                            )
+                            return fallback_outputs
+                        except ApiError:
+                            pass
+
+                    workflow_error = vellum_error_to_workflow_error(event.error)
+                    raise NodeException.of(workflow_error)
+
+        return outputs
+
+    def _retry_prompt_stream_with_fallbacks(self, tried_fallbacks: Set[str]) -> Optional[Iterator[ExecutePromptEvent]]:
+        if self.ml_model_fallbacks is not None:
+            for ml_model_fallback in self.ml_model_fallbacks:
+                if ml_model_fallback in tried_fallbacks:
+                    continue
+
+                try:
+                    tried_fallbacks.add(ml_model_fallback)
+                    prompt_event_stream = self._get_prompt_event_stream(ml_model_fallback=ml_model_fallback)
+                    next(prompt_event_stream)
+                    return prompt_event_stream
+                except ApiError:
+                    continue
+            else:
+                self._handle_api_error(
+                    ApiError(
+                        body={"detail": f"Failed to execute prompts with these fallbacks: {self.ml_model_fallbacks}"},
+                        status_code=400,
+                    )
+                )
+
+        return None
+
     def _compile_prompt_inputs(self) -> List[PromptDeploymentInputRequest]:
         # TODO: We may want to consolidate with subworkflow deployment input compilation
         # https://app.shortcut.com/vellum/story/4117
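The new _process_prompt_event_stream override gives deployed prompt nodes model fallbacks: on a non-5xx ApiError before the stream starts, or a REJECTED event carrying PROVIDER_ERROR, it retries each untried entry in ml_model_fallbacks, sending the fallback model through the overrides.ml_model_fallback body parameter. A minimal sketch of opting in, assuming PromptDeploymentNode (the displayable subclass of this base node) surfaces the ml_model_fallbacks attribute; the deployment and model names are placeholders:

from vellum.workflows.nodes.displayable import PromptDeploymentNode


class ResilientPrompt(PromptDeploymentNode):
    deployment = "my-prompt-deployment"
    prompt_inputs = {"question": "What is the capital of France?"}
    # Tried in order after a provider error; each model is attempted at most once.
    ml_model_fallbacks = ["gpt-4o-mini", "claude-3-5-sonnet-20241022"]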
@@ -1,12 +1,8 @@
 import pytest
 from uuid import UUID
 
-from vellum.workflows.events.types import (
-    CodeResourceDefinition,
-    NodeParentContext,
-    WorkflowDeploymentParentContext,
-    WorkflowParentContext,
-)
+from vellum.workflows.events.types import NodeParentContext, WorkflowDeploymentParentContext, WorkflowParentContext
+from vellum.workflows.types.definition import CodeResourceDefinition
 
 
 @pytest.fixture
@@ -117,7 +117,7 @@ class GuardrailNode(BaseNode[StateType], Generic[StateType]):
                     value=cast(Dict[str, Any], input_value),
                 )
             )
-        elif isinstance(input_value, float):
+        elif isinstance(input_value, (int, float)):
            compiled_inputs.append(
                 NumberInput(
                     name=input_name,
@@ -0,0 +1,50 @@
+from vellum import TestSuiteRunMetricNumberOutput
+from vellum.client.types.chat_history_input import ChatHistoryInput
+from vellum.client.types.chat_message import ChatMessage
+from vellum.client.types.json_input import JsonInput
+from vellum.client.types.metric_definition_execution import MetricDefinitionExecution
+from vellum.client.types.number_input import NumberInput
+from vellum.client.types.string_input import StringInput
+from vellum.workflows.nodes.displayable.guardrail_node.node import GuardrailNode
+
+
+def test_guardrail_node__inputs(vellum_client):
+    """Test that GuardrailNode correctly handles inputs."""
+
+    # GIVEN a Guardrail Node with inputs
+    class MyGuard(GuardrailNode):
+        metric_definition = "example_metric_definition"
+        metric_inputs = {
+            "a_string": "hello",
+            "a_chat_history": [ChatMessage(role="USER", text="Hello, how are you?")],
+            "a_dict": {"foo": "bar"},
+            "a_int": 42,
+            "a_float": 3.14,
+        }
+
+    vellum_client.metric_definitions.execute_metric_definition.return_value = MetricDefinitionExecution(
+        outputs=[
+            TestSuiteRunMetricNumberOutput(
+                name="score",
+                value=1.0,
+            ),
+        ],
+    )
+
+    # WHEN the node is run
+    MyGuard().run()
+
+    # THEN the metric_definitions.execute_metric_definition method should be called with the correct inputs
+    mock_api = vellum_client.metric_definitions.execute_metric_definition
+    assert mock_api.call_count == 1
+
+    assert mock_api.call_args.kwargs["inputs"] == [
+        StringInput(name="a_string", type="STRING", value="hello"),
+        ChatHistoryInput(
+            name="a_chat_history", type="CHAT_HISTORY", value=[ChatMessage(role="USER", text="Hello, how are you?")]
+        ),
+        JsonInput(name="a_dict", type="JSON", value={"foo": "bar"}),
+        NumberInput(name="a_int", type="NUMBER", value=42.0),
+        NumberInput(name="a_float", type="NUMBER", value=3.14),
+    ]
+    assert len(mock_api.call_args.kwargs["inputs"]) == 5