vellum-ai 0.14.37__py3-none-any.whl → 0.14.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their respective public registries.
- vellum/__init__.py +8 -0
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/reference.md +6272 -0
- vellum/client/types/__init__.py +8 -0
- vellum/client/types/ad_hoc_fulfilled_prompt_execution_meta.py +2 -0
- vellum/client/types/fulfilled_prompt_execution_meta.py +2 -0
- vellum/client/types/test_suite_run_exec_config_request.py +4 -0
- vellum/client/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +27 -0
- vellum/client/types/test_suite_run_prompt_sandbox_exec_config_request.py +29 -0
- vellum/client/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +22 -0
- vellum/client/types/test_suite_run_workflow_sandbox_exec_config_request.py +29 -0
- vellum/plugins/pydantic.py +1 -1
- vellum/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +3 -0
- vellum/types/test_suite_run_prompt_sandbox_exec_config_request.py +3 -0
- vellum/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +3 -0
- vellum/types/test_suite_run_workflow_sandbox_exec_config_request.py +3 -0
- vellum/workflows/events/node.py +2 -1
- vellum/workflows/events/types.py +3 -40
- vellum/workflows/events/workflow.py +2 -1
- vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +94 -3
- vellum/workflows/nodes/displayable/conftest.py +2 -6
- vellum/workflows/nodes/displayable/guardrail_node/node.py +1 -1
- vellum/workflows/nodes/displayable/guardrail_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/guardrail_node/tests/test_node.py +50 -0
- vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +297 -0
- vellum/workflows/runner/runner.py +44 -43
- vellum/workflows/state/base.py +149 -45
- vellum/workflows/types/definition.py +71 -0
- vellum/workflows/types/generics.py +34 -1
- vellum/workflows/workflows/base.py +20 -3
- vellum/workflows/workflows/tests/test_base_workflow.py +232 -1
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/RECORD +37 -25
- vellum_ee/workflows/display/vellum.py +0 -5
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/entry_points.txt +0 -0
vellum/client/types/__init__.py
CHANGED
```diff
@@ -453,6 +453,8 @@ from .test_suite_run_metric_json_output import TestSuiteRunMetricJsonOutput
 from .test_suite_run_metric_number_output import TestSuiteRunMetricNumberOutput
 from .test_suite_run_metric_output import TestSuiteRunMetricOutput
 from .test_suite_run_metric_string_output import TestSuiteRunMetricStringOutput
+from .test_suite_run_prompt_sandbox_exec_config_data_request import TestSuiteRunPromptSandboxExecConfigDataRequest
+from .test_suite_run_prompt_sandbox_exec_config_request import TestSuiteRunPromptSandboxExecConfigRequest
 from .test_suite_run_prompt_sandbox_history_item_exec_config import TestSuiteRunPromptSandboxHistoryItemExecConfig
 from .test_suite_run_prompt_sandbox_history_item_exec_config_data import (
     TestSuiteRunPromptSandboxHistoryItemExecConfigData,
@@ -472,6 +474,8 @@ from .test_suite_run_workflow_release_tag_exec_config_data_request import (
     TestSuiteRunWorkflowReleaseTagExecConfigDataRequest,
 )
 from .test_suite_run_workflow_release_tag_exec_config_request import TestSuiteRunWorkflowReleaseTagExecConfigRequest
+from .test_suite_run_workflow_sandbox_exec_config_data_request import TestSuiteRunWorkflowSandboxExecConfigDataRequest
+from .test_suite_run_workflow_sandbox_exec_config_request import TestSuiteRunWorkflowSandboxExecConfigRequest
 from .test_suite_run_workflow_sandbox_history_item_exec_config import TestSuiteRunWorkflowSandboxHistoryItemExecConfig
 from .test_suite_run_workflow_sandbox_history_item_exec_config_data import (
     TestSuiteRunWorkflowSandboxHistoryItemExecConfigData,
@@ -1048,6 +1052,8 @@ __all__ = [
     "TestSuiteRunMetricNumberOutput",
     "TestSuiteRunMetricOutput",
     "TestSuiteRunMetricStringOutput",
+    "TestSuiteRunPromptSandboxExecConfigDataRequest",
+    "TestSuiteRunPromptSandboxExecConfigRequest",
     "TestSuiteRunPromptSandboxHistoryItemExecConfig",
     "TestSuiteRunPromptSandboxHistoryItemExecConfigData",
     "TestSuiteRunPromptSandboxHistoryItemExecConfigDataRequest",
@@ -1059,6 +1065,8 @@ __all__ = [
     "TestSuiteRunWorkflowReleaseTagExecConfigData",
     "TestSuiteRunWorkflowReleaseTagExecConfigDataRequest",
     "TestSuiteRunWorkflowReleaseTagExecConfigRequest",
+    "TestSuiteRunWorkflowSandboxExecConfigDataRequest",
+    "TestSuiteRunWorkflowSandboxExecConfigRequest",
     "TestSuiteRunWorkflowSandboxHistoryItemExecConfig",
     "TestSuiteRunWorkflowSandboxHistoryItemExecConfigData",
     "TestSuiteRunWorkflowSandboxHistoryItemExecConfigDataRequest",
```
vellum/client/types/ad_hoc_fulfilled_prompt_execution_meta.py
CHANGED
```diff
@@ -4,6 +4,7 @@ from ..core.pydantic_utilities import UniversalBaseModel
 import typing
 from .finish_reason_enum import FinishReasonEnum
 from .ml_model_usage import MlModelUsage
+from .price import Price
 from ..core.pydantic_utilities import IS_PYDANTIC_V2
 import pydantic
 
@@ -16,6 +17,7 @@ class AdHocFulfilledPromptExecutionMeta(UniversalBaseModel):
     latency: typing.Optional[int] = None
     finish_reason: typing.Optional[FinishReasonEnum] = None
     usage: typing.Optional[MlModelUsage] = None
+    cost: typing.Optional[Price] = None
 
     if IS_PYDANTIC_V2:
         model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
```
vellum/client/types/fulfilled_prompt_execution_meta.py
CHANGED
```diff
@@ -4,6 +4,7 @@ from ..core.pydantic_utilities import UniversalBaseModel
 import typing
 from .finish_reason_enum import FinishReasonEnum
 from .ml_model_usage import MlModelUsage
+from .price import Price
 from ..core.pydantic_utilities import IS_PYDANTIC_V2
 import pydantic
 
@@ -16,6 +17,7 @@ class FulfilledPromptExecutionMeta(UniversalBaseModel):
     latency: typing.Optional[int] = None
     finish_reason: typing.Optional[FinishReasonEnum] = None
     usage: typing.Optional[MlModelUsage] = None
+    cost: typing.Optional[Price] = None
 
     if IS_PYDANTIC_V2:
         model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
```
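Both execution-meta models gain an optional `cost` field of type `Price`. A minimal sketch of reading it after a prompt execution (the deployment name is an illustrative placeholder, and this assumes `Price` exposes `value` and `unit`):

```python
from vellum.client import Vellum

client = Vellum(api_key="YOUR_API_KEY")  # placeholder credentials
response = client.execute_prompt(
    prompt_deployment_name="my-deployment",  # hypothetical deployment
    inputs=[],
)
# On fulfilled executions, meta.cost is populated when pricing data is available.
if response.state == "FULFILLED" and response.meta and response.meta.cost:
    print(response.meta.cost.value, response.meta.cost.unit)
```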
vellum/client/types/test_suite_run_exec_config_request.py
CHANGED
```diff
@@ -2,10 +2,12 @@
 
 import typing
 from .test_suite_run_deployment_release_tag_exec_config_request import TestSuiteRunDeploymentReleaseTagExecConfigRequest
+from .test_suite_run_prompt_sandbox_exec_config_request import TestSuiteRunPromptSandboxExecConfigRequest
 from .test_suite_run_prompt_sandbox_history_item_exec_config_request import (
     TestSuiteRunPromptSandboxHistoryItemExecConfigRequest,
 )
 from .test_suite_run_workflow_release_tag_exec_config_request import TestSuiteRunWorkflowReleaseTagExecConfigRequest
+from .test_suite_run_workflow_sandbox_exec_config_request import TestSuiteRunWorkflowSandboxExecConfigRequest
 from .test_suite_run_workflow_sandbox_history_item_exec_config_request import (
     TestSuiteRunWorkflowSandboxHistoryItemExecConfigRequest,
 )
@@ -13,8 +15,10 @@ from .test_suite_run_external_exec_config_request import TestSuiteRunExternalExe
 
 TestSuiteRunExecConfigRequest = typing.Union[
     TestSuiteRunDeploymentReleaseTagExecConfigRequest,
+    TestSuiteRunPromptSandboxExecConfigRequest,
     TestSuiteRunPromptSandboxHistoryItemExecConfigRequest,
     TestSuiteRunWorkflowReleaseTagExecConfigRequest,
+    TestSuiteRunWorkflowSandboxExecConfigRequest,
     TestSuiteRunWorkflowSandboxHistoryItemExecConfigRequest,
     TestSuiteRunExternalExecConfigRequest,
 ]
```
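Each union member carries a distinct literal `type` field, so payloads resolve to the matching model. A quick sketch, assuming a Pydantic v2 environment (the SDK performs the equivalent resolution internally):

```python
from pydantic import TypeAdapter

from vellum.client.types.test_suite_run_exec_config_request import TestSuiteRunExecConfigRequest

adapter = TypeAdapter(TestSuiteRunExecConfigRequest)
config = adapter.validate_python(
    {
        "type": "PROMPT_SANDBOX",
        "data": {
            "prompt_sandbox_id": "00000000-0000-0000-0000-000000000001",  # placeholder IDs
            "prompt_variant_id": "00000000-0000-0000-0000-000000000002",
        },
    }
)
assert type(config).__name__ == "TestSuiteRunPromptSandboxExecConfigRequest"
```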
vellum/client/types/test_suite_run_prompt_sandbox_exec_config_data_request.py
ADDED
```diff
@@ -0,0 +1,27 @@
+# This file was auto-generated by Fern from our API Definition.
+
+from ..core.pydantic_utilities import UniversalBaseModel
+import pydantic
+from ..core.pydantic_utilities import IS_PYDANTIC_V2
+import typing
+
+
+class TestSuiteRunPromptSandboxExecConfigDataRequest(UniversalBaseModel):
+    prompt_sandbox_id: str = pydantic.Field()
+    """
+    The ID of the Prompt Sandbox to run the Test Suite against.
+    """
+
+    prompt_variant_id: str = pydantic.Field()
+    """
+    The ID of the Prompt Variant within the Prompt Sandbox that you'd like to run the Test Suite against.
+    """
+
+    if IS_PYDANTIC_V2:
+        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
+    else:
+
+        class Config:
+            frozen = True
+            smart_union = True
+            extra = pydantic.Extra.allow
```
vellum/client/types/test_suite_run_prompt_sandbox_exec_config_request.py
ADDED
```diff
@@ -0,0 +1,29 @@
+# This file was auto-generated by Fern from our API Definition.
+
+from ..core.pydantic_utilities import UniversalBaseModel
+import typing
+from .test_suite_run_prompt_sandbox_exec_config_data_request import TestSuiteRunPromptSandboxExecConfigDataRequest
+import pydantic
+from ..core.pydantic_utilities import IS_PYDANTIC_V2
+
+
+class TestSuiteRunPromptSandboxExecConfigRequest(UniversalBaseModel):
+    """
+    Execution configuration for running a Test Suite against a Prompt Sandbox
+    """
+
+    type: typing.Literal["PROMPT_SANDBOX"] = "PROMPT_SANDBOX"
+    data: TestSuiteRunPromptSandboxExecConfigDataRequest
+    test_case_ids: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
+    """
+    Optionally specify a subset of test case ids to run. If not provided, all test cases within the test suite will be run by default.
+    """
+
+    if IS_PYDANTIC_V2:
+        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
+    else:
+
+        class Config:
+            frozen = True
+            smart_union = True
+            extra = pydantic.Extra.allow
```
vellum/client/types/test_suite_run_workflow_sandbox_exec_config_data_request.py
ADDED
```diff
@@ -0,0 +1,22 @@
+# This file was auto-generated by Fern from our API Definition.
+
+from ..core.pydantic_utilities import UniversalBaseModel
+import pydantic
+from ..core.pydantic_utilities import IS_PYDANTIC_V2
+import typing
+
+
+class TestSuiteRunWorkflowSandboxExecConfigDataRequest(UniversalBaseModel):
+    workflow_sandbox_id: str = pydantic.Field()
+    """
+    The ID of the Workflow Sandbox to run the Test Suite against.
+    """
+
+    if IS_PYDANTIC_V2:
+        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
+    else:
+
+        class Config:
+            frozen = True
+            smart_union = True
+            extra = pydantic.Extra.allow
```
vellum/client/types/test_suite_run_workflow_sandbox_exec_config_request.py
ADDED
```diff
@@ -0,0 +1,29 @@
+# This file was auto-generated by Fern from our API Definition.
+
+from ..core.pydantic_utilities import UniversalBaseModel
+import typing
+from .test_suite_run_workflow_sandbox_exec_config_data_request import TestSuiteRunWorkflowSandboxExecConfigDataRequest
+import pydantic
+from ..core.pydantic_utilities import IS_PYDANTIC_V2
+
+
+class TestSuiteRunWorkflowSandboxExecConfigRequest(UniversalBaseModel):
+    """
+    Execution configuration for running a Test Suite against a Workflow Sandbox
+    """
+
+    type: typing.Literal["WORKFLOW_SANDBOX"] = "WORKFLOW_SANDBOX"
+    data: TestSuiteRunWorkflowSandboxExecConfigDataRequest
+    test_case_ids: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
+    """
+    Optionally specify a subset of test case ids to run. If not provided, all test cases within the test suite will be run by default.
+    """
+
+    if IS_PYDANTIC_V2:
+        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
+    else:
+
+        class Config:
+            frozen = True
+            smart_union = True
+            extra = pydantic.Extra.allow
```
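With both halves of each config defined, constructing a request is straightforward. A minimal sketch (the sandbox ID is a placeholder) showing that the `type` discriminator is populated by default:

```python
from vellum.client.types.test_suite_run_workflow_sandbox_exec_config_data_request import (
    TestSuiteRunWorkflowSandboxExecConfigDataRequest,
)
from vellum.client.types.test_suite_run_workflow_sandbox_exec_config_request import (
    TestSuiteRunWorkflowSandboxExecConfigRequest,
)

exec_config = TestSuiteRunWorkflowSandboxExecConfigRequest(
    data=TestSuiteRunWorkflowSandboxExecConfigDataRequest(
        workflow_sandbox_id="00000000-0000-0000-0000-000000000000",  # placeholder ID
    ),
    # test_case_ids omitted: all test cases in the suite run by default
)
assert exec_config.type == "WORKFLOW_SANDBOX"
```

Either of the two new variants can be supplied anywhere the SDK accepts a `TestSuiteRunExecConfigRequest`.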
vellum/plugins/pydantic.py
CHANGED
```diff
@@ -53,7 +53,7 @@ class OnValidatePython(ValidatePythonHandlerProtocol):
             return
 
         if self_instance:
-            model_fields: Dict[str, FieldInfo] = self_instance.model_fields
+            model_fields: Dict[str, FieldInfo] = self_instance.__class__.model_fields
         else:
             model_fields = {}
 
```
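This one-line change sidesteps Pydantic's deprecation of reading `model_fields` off an instance (deprecated as of Pydantic 2.11 in favor of class-level access); the returned mapping is the same either way. A minimal sketch:

```python
from pydantic import BaseModel


class Example(BaseModel):
    name: str


obj = Example(name="vellum")

# Class-level access: no DeprecationWarning on Pydantic >= 2.11.
assert "name" in obj.__class__.model_fields
# obj.model_fields returns the same mapping but warns on newer Pydantic versions.
```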
vellum/workflows/events/node.py
CHANGED
```diff
@@ -8,9 +8,10 @@ from vellum.workflows.expressions.accessor import AccessorExpression
 from vellum.workflows.outputs.base import BaseOutput
 from vellum.workflows.ports.port import Port
 from vellum.workflows.references.node import NodeReference
+from vellum.workflows.types.definition import serialize_type_encoder_with_id
 from vellum.workflows.types.generics import OutputsType
 
-from .types import BaseEvent, default_serializer, serialize_type_encoder_with_id
+from .types import BaseEvent, default_serializer
 
 if TYPE_CHECKING:
     from vellum.workflows.nodes.bases import BaseNode
```
vellum/workflows/events/types.py
CHANGED
```diff
@@ -1,13 +1,14 @@
 from datetime import datetime
 import json
 from uuid import UUID, uuid4
-from typing import Annotated, Any, Dict, List, Literal, Optional, Union, get_args
+from typing import Annotated, Any, Literal, Optional, Union, get_args
 
-from pydantic import BeforeValidator, Field, GetCoreSchemaHandler, Tag, ValidationInfo
+from pydantic import Field, GetCoreSchemaHandler, Tag, ValidationInfo
 from pydantic_core import CoreSchema, core_schema
 
 from vellum.core.pydantic_utilities import UniversalBaseModel
 from vellum.workflows.state.encoder import DefaultStateEncoder
+from vellum.workflows.types.definition import VellumCodeResourceDefinition
 from vellum.workflows.types.utils import datetime_now
 
 
@@ -19,28 +20,6 @@ def default_datetime_factory() -> datetime:
     return datetime_now()
 
 
-excluded_modules = {"typing", "builtins"}
-
-
-def serialize_type_encoder(obj: type) -> Dict[str, Any]:
-    return {
-        "name": obj.__name__,
-        "module": obj.__module__.split("."),
-    }
-
-
-def serialize_type_encoder_with_id(obj: Union[type, "CodeResourceDefinition"]) -> Dict[str, Any]:
-    if hasattr(obj, "__id__") and isinstance(obj, type):
-        return {
-            "id": getattr(obj, "__id__"),
-            **serialize_type_encoder(obj),
-        }
-    elif isinstance(obj, CodeResourceDefinition):
-        return obj.model_dump(mode="json")
-
-    raise AttributeError(f"The object of type '{type(obj).__name__}' must have an '__id__' attribute.")
-
-
 def default_serializer(obj: Any) -> Any:
     return json.loads(
         json.dumps(
@@ -50,22 +29,6 @@ def default_serializer(obj: Any) -> Any:
     )
 
 
-class CodeResourceDefinition(UniversalBaseModel):
-    id: UUID
-    name: str
-    module: List[str]
-
-    @staticmethod
-    def encode(obj: type) -> "CodeResourceDefinition":
-        return CodeResourceDefinition(**serialize_type_encoder_with_id(obj))
-
-
-VellumCodeResourceDefinition = Annotated[
-    CodeResourceDefinition,
-    BeforeValidator(lambda d: (d if type(d) is dict else serialize_type_encoder_with_id(d))),
-]
-
-
 class BaseParentContext(UniversalBaseModel):
     span_id: UUID
     parent: Optional["ParentContext"] = None
```
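These helpers are relocated rather than deleted: per the file list above, `CodeResourceDefinition`, `VellumCodeResourceDefinition`, and `serialize_type_encoder_with_id` now live in the new `vellum/workflows/types/definition.py` (+71 lines), which the events modules import from instead. A small sketch of the helper's behavior, reconstructed from the code removed above:

```python
from uuid import uuid4

from vellum.workflows.types.definition import serialize_type_encoder_with_id


class MyNode:
    __id__ = uuid4()  # types without an __id__ attribute raise AttributeError


encoded = serialize_type_encoder_with_id(MyNode)
# e.g. {"id": UUID("..."), "name": "MyNode", "module": ["__main__"]}
assert encoded["name"] == "MyNode"
```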
vellum/workflows/events/workflow.py
CHANGED
```diff
@@ -8,6 +8,7 @@ from vellum.core.pydantic_utilities import UniversalBaseModel
 from vellum.workflows.errors import WorkflowError
 from vellum.workflows.outputs.base import BaseOutput
 from vellum.workflows.references import ExternalInputReference
+from vellum.workflows.types.definition import serialize_type_encoder_with_id
 from vellum.workflows.types.generics import InputsType, OutputsType, StateType
 
 from .node import (
@@ -18,7 +19,7 @@ from .node import (
     NodeExecutionResumedEvent,
     NodeExecutionStreamingEvent,
 )
-from .types import BaseEvent, default_serializer, serialize_type_encoder_with_id
+from .types import BaseEvent, default_serializer
 
 if TYPE_CHECKING:
     from vellum.workflows.workflows.base import BaseWorkflow
```
vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py
CHANGED
```diff
@@ -1,6 +1,6 @@
 import json
 from uuid import UUID
-from typing import Any, ClassVar, Dict, Generic, Iterator, List, Optional, Sequence, Union
+from typing import Any, ClassVar, Dict, Generator, Generic, Iterator, List, Optional, Sequence, Set, Union
 
 from vellum import (
     ChatHistoryInputRequest,
@@ -9,17 +9,20 @@ from vellum import (
     JsonInputRequest,
     PromptDeploymentExpandMetaRequest,
     PromptDeploymentInputRequest,
+    PromptOutput,
     RawPromptExecutionOverridesRequest,
     StringInputRequest,
 )
-from vellum.client import RequestOptions
+from vellum.client import ApiError, RequestOptions
 from vellum.client.types.chat_message_request import ChatMessageRequest
 from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT
 from vellum.workflows.context import get_execution_context
 from vellum.workflows.errors import WorkflowErrorCode
+from vellum.workflows.errors.types import vellum_error_to_workflow_error
 from vellum.workflows.events.types import default_serializer
 from vellum.workflows.exceptions import NodeException
 from vellum.workflows.nodes.displayable.bases.base_prompt_node import BasePromptNode
+from vellum.workflows.outputs import BaseOutput
 from vellum.workflows.types import MergeBehavior
 from vellum.workflows.types.generics import StateType
 
@@ -56,13 +59,21 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
     class Trigger(BasePromptNode.Trigger):
         merge_behavior = MergeBehavior.AWAIT_ANY
 
-    def _get_prompt_event_stream(self) -> Iterator[ExecutePromptEvent]:
+    def _get_prompt_event_stream(self, ml_model_fallback: Optional[str] = None) -> Iterator[ExecutePromptEvent]:
         execution_context = get_execution_context()
         request_options = self.request_options or RequestOptions()
         request_options["additional_body_parameters"] = {
             "execution_context": execution_context.model_dump(mode="json"),
             **request_options.get("additional_body_parameters", {}),
         }
+        if ml_model_fallback:
+            request_options["additional_body_parameters"] = {
+                "overrides": {
+                    "ml_model_fallback": ml_model_fallback,
+                },
+                **request_options.get("additional_body_parameters", {}),
+            }
+
         return self._context.vellum_client.execute_prompt_stream(
             inputs=self._compile_prompt_inputs(),
             prompt_deployment_id=str(self.deployment) if isinstance(self.deployment, UUID) else None,
@@ -76,6 +87,86 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
             request_options=request_options,
         )
 
+    def _process_prompt_event_stream(
+        self,
+        prompt_event_stream: Optional[Iterator[ExecutePromptEvent]] = None,
+        tried_fallbacks: Optional[set[str]] = None,
+    ) -> Generator[BaseOutput, None, Optional[List[PromptOutput]]]:
+        """Override the base prompt node _process_prompt_event_stream()"""
+        self._validate()
+
+        if tried_fallbacks is None:
+            tried_fallbacks = set()
+
+        if prompt_event_stream is None:
+            try:
+                prompt_event_stream = self._get_prompt_event_stream()
+                next(prompt_event_stream)
+            except ApiError as e:
+                if (
+                    e.status_code
+                    and e.status_code < 500
+                    and self.ml_model_fallbacks is not OMIT
+                    and self.ml_model_fallbacks is not None
+                ):
+                    prompt_event_stream = self._retry_prompt_stream_with_fallbacks(tried_fallbacks)
+                else:
+                    self._handle_api_error(e)
+
+        outputs: Optional[List[PromptOutput]] = None
+        if prompt_event_stream is not None:
+            for event in prompt_event_stream:
+                if event.state == "INITIATED":
+                    continue
+                elif event.state == "STREAMING":
+                    yield BaseOutput(name="results", delta=event.output.value)
+                elif event.state == "FULFILLED":
+                    outputs = event.outputs
+                    yield BaseOutput(name="results", value=event.outputs)
+                elif event.state == "REJECTED":
+                    if (
+                        event.error
+                        and event.error.code == WorkflowErrorCode.PROVIDER_ERROR.value
+                        and self.ml_model_fallbacks is not OMIT
+                        and self.ml_model_fallbacks is not None
+                    ):
+                        try:
+                            fallback_stream = self._retry_prompt_stream_with_fallbacks(tried_fallbacks)
+                            fallback_outputs = yield from self._process_prompt_event_stream(
+                                fallback_stream, tried_fallbacks
+                            )
+                            return fallback_outputs
+                        except ApiError:
+                            pass
+
+                    workflow_error = vellum_error_to_workflow_error(event.error)
+                    raise NodeException.of(workflow_error)
+
+        return outputs
+
+    def _retry_prompt_stream_with_fallbacks(self, tried_fallbacks: Set[str]) -> Optional[Iterator[ExecutePromptEvent]]:
+        if self.ml_model_fallbacks is not None:
+            for ml_model_fallback in self.ml_model_fallbacks:
+                if ml_model_fallback in tried_fallbacks:
+                    continue
+
+                try:
+                    tried_fallbacks.add(ml_model_fallback)
+                    prompt_event_stream = self._get_prompt_event_stream(ml_model_fallback=ml_model_fallback)
+                    next(prompt_event_stream)
+                    return prompt_event_stream
+                except ApiError:
+                    continue
+            else:
+                self._handle_api_error(
+                    ApiError(
+                        body={"detail": f"Failed to execute prompts with these fallbacks: {self.ml_model_fallbacks}"},
+                        status_code=400,
+                    )
+                )
+
+        return None
+
     def _compile_prompt_inputs(self) -> List[PromptDeploymentInputRequest]:
         # TODO: We may want to consolidate with subworkflow deployment input compilation
         # https://app.shortcut.com/vellum/story/4117
```
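Taken together, the new methods give deployed-prompt nodes model-fallback behavior: a 4xx error on the initial request, or a PROVIDER_ERROR rejection mid-stream, triggers at most one retry per untried entry in `ml_model_fallbacks`, each sent through the `ml_model_fallback` override. A usage sketch (the deployment and model names are illustrative placeholders, and the `ml_model_fallbacks` attribute is taken from this diff):

```python
from vellum.workflows.nodes.displayable import PromptDeploymentNode


class ResilientPromptNode(PromptDeploymentNode):
    deployment = "my-prompt-deployment"  # hypothetical deployment name
    prompt_inputs = {"question": "Summarize the report."}
    # Fallbacks are tried once each, in order. If every fallback fails,
    # the node reports a 400 ApiError via _handle_api_error.
    ml_model_fallbacks = ["gpt-4o-mini", "claude-3-5-sonnet-latest"]  # hypothetical model names
```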
vellum/workflows/nodes/displayable/conftest.py
CHANGED
```diff
@@ -1,12 +1,8 @@
 import pytest
 from uuid import UUID
 
-from vellum.workflows.events.types import (
-    CodeResourceDefinition,
-    NodeParentContext,
-    WorkflowDeploymentParentContext,
-    WorkflowParentContext,
-)
+from vellum.workflows.events.types import NodeParentContext, WorkflowDeploymentParentContext, WorkflowParentContext
+from vellum.workflows.types.definition import CodeResourceDefinition
 
 
 @pytest.fixture
```
|
@@ -117,7 +117,7 @@ class GuardrailNode(BaseNode[StateType], Generic[StateType]):
|
|
117
117
|
value=cast(Dict[str, Any], input_value),
|
118
118
|
)
|
119
119
|
)
|
120
|
-
elif isinstance(input_value, float):
|
120
|
+
elif isinstance(input_value, (int, float)):
|
121
121
|
compiled_inputs.append(
|
122
122
|
NumberInput(
|
123
123
|
name=input_name,
|
vellum/workflows/nodes/displayable/guardrail_node/tests/__init__.py
File without changes
vellum/workflows/nodes/displayable/guardrail_node/tests/test_node.py
ADDED
```diff
@@ -0,0 +1,50 @@
+from vellum import TestSuiteRunMetricNumberOutput
+from vellum.client.types.chat_history_input import ChatHistoryInput
+from vellum.client.types.chat_message import ChatMessage
+from vellum.client.types.json_input import JsonInput
+from vellum.client.types.metric_definition_execution import MetricDefinitionExecution
+from vellum.client.types.number_input import NumberInput
+from vellum.client.types.string_input import StringInput
+from vellum.workflows.nodes.displayable.guardrail_node.node import GuardrailNode
+
+
+def test_guardrail_node__inputs(vellum_client):
+    """Test that GuardrailNode correctly handles inputs."""
+
+    # GIVEN a Guardrail Node with inputs
+    class MyGuard(GuardrailNode):
+        metric_definition = "example_metric_definition"
+        metric_inputs = {
+            "a_string": "hello",
+            "a_chat_history": [ChatMessage(role="USER", text="Hello, how are you?")],
+            "a_dict": {"foo": "bar"},
+            "a_int": 42,
+            "a_float": 3.14,
+        }
+
+    vellum_client.metric_definitions.execute_metric_definition.return_value = MetricDefinitionExecution(
+        outputs=[
+            TestSuiteRunMetricNumberOutput(
+                name="score",
+                value=1.0,
+            ),
+        ],
+    )
+
+    # WHEN the node is run
+    MyGuard().run()
+
+    # THEN the metric_definitions.execute_metric_definition method should be called with the correct inputs
+    mock_api = vellum_client.metric_definitions.execute_metric_definition
+    assert mock_api.call_count == 1
+
+    assert mock_api.call_args.kwargs["inputs"] == [
+        StringInput(name="a_string", type="STRING", value="hello"),
+        ChatHistoryInput(
+            name="a_chat_history", type="CHAT_HISTORY", value=[ChatMessage(role="USER", text="Hello, how are you?")]
+        ),
+        JsonInput(name="a_dict", type="JSON", value={"foo": "bar"}),
+        NumberInput(name="a_int", type="NUMBER", value=42.0),
+        NumberInput(name="a_float", type="NUMBER", value=3.14),
+    ]
+    assert len(mock_api.call_args.kwargs["inputs"]) == 5
```