vellum-ai 0.14.37__py3-none-any.whl → 0.14.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. vellum/__init__.py +10 -0
  2. vellum/client/core/client_wrapper.py +1 -1
  3. vellum/client/reference.md +6272 -0
  4. vellum/client/types/__init__.py +10 -0
  5. vellum/client/types/ad_hoc_fulfilled_prompt_execution_meta.py +2 -0
  6. vellum/client/types/fulfilled_prompt_execution_meta.py +2 -0
  7. vellum/client/types/test_suite_run_exec_config_request.py +4 -0
  8. vellum/client/types/test_suite_run_progress.py +20 -0
  9. vellum/client/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +27 -0
  10. vellum/client/types/test_suite_run_prompt_sandbox_exec_config_request.py +29 -0
  11. vellum/client/types/test_suite_run_read.py +3 -0
  12. vellum/client/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +22 -0
  13. vellum/client/types/test_suite_run_workflow_sandbox_exec_config_request.py +29 -0
  14. vellum/client/types/vellum_sdk_error_code_enum.py +1 -0
  15. vellum/client/types/workflow_execution_event_error_code.py +1 -0
  16. vellum/plugins/pydantic.py +1 -1
  17. vellum/types/test_suite_run_progress.py +3 -0
  18. vellum/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +3 -0
  19. vellum/types/test_suite_run_prompt_sandbox_exec_config_request.py +3 -0
  20. vellum/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +3 -0
  21. vellum/types/test_suite_run_workflow_sandbox_exec_config_request.py +3 -0
  22. vellum/workflows/errors/types.py +1 -0
  23. vellum/workflows/events/node.py +2 -1
  24. vellum/workflows/events/tests/test_event.py +1 -0
  25. vellum/workflows/events/types.py +3 -40
  26. vellum/workflows/events/workflow.py +15 -4
  27. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +7 -1
  28. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +94 -3
  29. vellum/workflows/nodes/displayable/conftest.py +2 -6
  30. vellum/workflows/nodes/displayable/guardrail_node/node.py +1 -1
  31. vellum/workflows/nodes/displayable/guardrail_node/tests/__init__.py +0 -0
  32. vellum/workflows/nodes/displayable/guardrail_node/tests/test_node.py +50 -0
  33. vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +6 -1
  34. vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +323 -0
  35. vellum/workflows/runner/runner.py +78 -57
  36. vellum/workflows/state/base.py +177 -50
  37. vellum/workflows/state/tests/test_state.py +26 -20
  38. vellum/workflows/types/definition.py +71 -0
  39. vellum/workflows/types/generics.py +34 -1
  40. vellum/workflows/workflows/base.py +26 -19
  41. vellum/workflows/workflows/tests/test_base_workflow.py +232 -1
  42. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.39.dist-info}/METADATA +1 -1
  43. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.39.dist-info}/RECORD +49 -35
  44. vellum_cli/push.py +2 -3
  45. vellum_cli/tests/test_push.py +52 -0
  46. vellum_ee/workflows/display/vellum.py +0 -5
  47. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.39.dist-info}/LICENSE +0 -0
  48. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.39.dist-info}/WHEEL +0 -0
  49. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.39.dist-info}/entry_points.txt +0 -0
@@ -453,6 +453,9 @@ from .test_suite_run_metric_json_output import TestSuiteRunMetricJsonOutput
453
453
  from .test_suite_run_metric_number_output import TestSuiteRunMetricNumberOutput
454
454
  from .test_suite_run_metric_output import TestSuiteRunMetricOutput
455
455
  from .test_suite_run_metric_string_output import TestSuiteRunMetricStringOutput
456
+ from .test_suite_run_progress import TestSuiteRunProgress
457
+ from .test_suite_run_prompt_sandbox_exec_config_data_request import TestSuiteRunPromptSandboxExecConfigDataRequest
458
+ from .test_suite_run_prompt_sandbox_exec_config_request import TestSuiteRunPromptSandboxExecConfigRequest
456
459
  from .test_suite_run_prompt_sandbox_history_item_exec_config import TestSuiteRunPromptSandboxHistoryItemExecConfig
457
460
  from .test_suite_run_prompt_sandbox_history_item_exec_config_data import (
458
461
  TestSuiteRunPromptSandboxHistoryItemExecConfigData,
@@ -472,6 +475,8 @@ from .test_suite_run_workflow_release_tag_exec_config_data_request import (
472
475
  TestSuiteRunWorkflowReleaseTagExecConfigDataRequest,
473
476
  )
474
477
  from .test_suite_run_workflow_release_tag_exec_config_request import TestSuiteRunWorkflowReleaseTagExecConfigRequest
478
+ from .test_suite_run_workflow_sandbox_exec_config_data_request import TestSuiteRunWorkflowSandboxExecConfigDataRequest
479
+ from .test_suite_run_workflow_sandbox_exec_config_request import TestSuiteRunWorkflowSandboxExecConfigRequest
475
480
  from .test_suite_run_workflow_sandbox_history_item_exec_config import TestSuiteRunWorkflowSandboxHistoryItemExecConfig
476
481
  from .test_suite_run_workflow_sandbox_history_item_exec_config_data import (
477
482
  TestSuiteRunWorkflowSandboxHistoryItemExecConfigData,
@@ -1048,6 +1053,9 @@ __all__ = [
1048
1053
  "TestSuiteRunMetricNumberOutput",
1049
1054
  "TestSuiteRunMetricOutput",
1050
1055
  "TestSuiteRunMetricStringOutput",
1056
+ "TestSuiteRunProgress",
1057
+ "TestSuiteRunPromptSandboxExecConfigDataRequest",
1058
+ "TestSuiteRunPromptSandboxExecConfigRequest",
1051
1059
  "TestSuiteRunPromptSandboxHistoryItemExecConfig",
1052
1060
  "TestSuiteRunPromptSandboxHistoryItemExecConfigData",
1053
1061
  "TestSuiteRunPromptSandboxHistoryItemExecConfigDataRequest",
@@ -1059,6 +1067,8 @@ __all__ = [
1059
1067
  "TestSuiteRunWorkflowReleaseTagExecConfigData",
1060
1068
  "TestSuiteRunWorkflowReleaseTagExecConfigDataRequest",
1061
1069
  "TestSuiteRunWorkflowReleaseTagExecConfigRequest",
1070
+ "TestSuiteRunWorkflowSandboxExecConfigDataRequest",
1071
+ "TestSuiteRunWorkflowSandboxExecConfigRequest",
1062
1072
  "TestSuiteRunWorkflowSandboxHistoryItemExecConfig",
1063
1073
  "TestSuiteRunWorkflowSandboxHistoryItemExecConfigData",
1064
1074
  "TestSuiteRunWorkflowSandboxHistoryItemExecConfigDataRequest",
@@ -4,6 +4,7 @@ from ..core.pydantic_utilities import UniversalBaseModel
4
4
  import typing
5
5
  from .finish_reason_enum import FinishReasonEnum
6
6
  from .ml_model_usage import MlModelUsage
7
+ from .price import Price
7
8
  from ..core.pydantic_utilities import IS_PYDANTIC_V2
8
9
  import pydantic
9
10
 
@@ -16,6 +17,7 @@ class AdHocFulfilledPromptExecutionMeta(UniversalBaseModel):
16
17
  latency: typing.Optional[int] = None
17
18
  finish_reason: typing.Optional[FinishReasonEnum] = None
18
19
  usage: typing.Optional[MlModelUsage] = None
20
+ cost: typing.Optional[Price] = None
19
21
 
20
22
  if IS_PYDANTIC_V2:
21
23
  model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
@@ -4,6 +4,7 @@ from ..core.pydantic_utilities import UniversalBaseModel
4
4
  import typing
5
5
  from .finish_reason_enum import FinishReasonEnum
6
6
  from .ml_model_usage import MlModelUsage
7
+ from .price import Price
7
8
  from ..core.pydantic_utilities import IS_PYDANTIC_V2
8
9
  import pydantic
9
10
 
@@ -16,6 +17,7 @@ class FulfilledPromptExecutionMeta(UniversalBaseModel):
16
17
  latency: typing.Optional[int] = None
17
18
  finish_reason: typing.Optional[FinishReasonEnum] = None
18
19
  usage: typing.Optional[MlModelUsage] = None
20
+ cost: typing.Optional[Price] = None
19
21
 
20
22
  if IS_PYDANTIC_V2:
21
23
  model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
@@ -2,10 +2,12 @@
2
2
 
3
3
  import typing
4
4
  from .test_suite_run_deployment_release_tag_exec_config_request import TestSuiteRunDeploymentReleaseTagExecConfigRequest
5
+ from .test_suite_run_prompt_sandbox_exec_config_request import TestSuiteRunPromptSandboxExecConfigRequest
5
6
  from .test_suite_run_prompt_sandbox_history_item_exec_config_request import (
6
7
  TestSuiteRunPromptSandboxHistoryItemExecConfigRequest,
7
8
  )
8
9
  from .test_suite_run_workflow_release_tag_exec_config_request import TestSuiteRunWorkflowReleaseTagExecConfigRequest
10
+ from .test_suite_run_workflow_sandbox_exec_config_request import TestSuiteRunWorkflowSandboxExecConfigRequest
9
11
  from .test_suite_run_workflow_sandbox_history_item_exec_config_request import (
10
12
  TestSuiteRunWorkflowSandboxHistoryItemExecConfigRequest,
11
13
  )
@@ -13,8 +15,10 @@ from .test_suite_run_external_exec_config_request import TestSuiteRunExternalExe
13
15
 
14
16
  TestSuiteRunExecConfigRequest = typing.Union[
15
17
  TestSuiteRunDeploymentReleaseTagExecConfigRequest,
18
+ TestSuiteRunPromptSandboxExecConfigRequest,
16
19
  TestSuiteRunPromptSandboxHistoryItemExecConfigRequest,
17
20
  TestSuiteRunWorkflowReleaseTagExecConfigRequest,
21
+ TestSuiteRunWorkflowSandboxExecConfigRequest,
18
22
  TestSuiteRunWorkflowSandboxHistoryItemExecConfigRequest,
19
23
  TestSuiteRunExternalExecConfigRequest,
20
24
  ]
@@ -0,0 +1,20 @@
1
+ # This file was auto-generated by Fern from our API Definition.
2
+
3
+ from ..core.pydantic_utilities import UniversalBaseModel
4
+ from ..core.pydantic_utilities import IS_PYDANTIC_V2
5
+ import typing
6
+ import pydantic
7
+
8
+
9
+ class TestSuiteRunProgress(UniversalBaseModel):
10
+ number_of_requested_test_cases: int
11
+ number_of_completed_test_cases: int
12
+
13
+ if IS_PYDANTIC_V2:
14
+ model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
15
+ else:
16
+
17
+ class Config:
18
+ frozen = True
19
+ smart_union = True
20
+ extra = pydantic.Extra.allow
@@ -0,0 +1,27 @@
1
+ # This file was auto-generated by Fern from our API Definition.
2
+
3
+ from ..core.pydantic_utilities import UniversalBaseModel
4
+ import pydantic
5
+ from ..core.pydantic_utilities import IS_PYDANTIC_V2
6
+ import typing
7
+
8
+
9
+ class TestSuiteRunPromptSandboxExecConfigDataRequest(UniversalBaseModel):
10
+ prompt_sandbox_id: str = pydantic.Field()
11
+ """
12
+ The ID of the Prompt Sandbox to run the Test Suite against.
13
+ """
14
+
15
+ prompt_variant_id: str = pydantic.Field()
16
+ """
17
+ The ID of the Prompt Variant within the Prompt Sandbox that you'd like to run the Test Suite against.
18
+ """
19
+
20
+ if IS_PYDANTIC_V2:
21
+ model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
22
+ else:
23
+
24
+ class Config:
25
+ frozen = True
26
+ smart_union = True
27
+ extra = pydantic.Extra.allow
@@ -0,0 +1,29 @@
1
+ # This file was auto-generated by Fern from our API Definition.
2
+
3
+ from ..core.pydantic_utilities import UniversalBaseModel
4
+ import typing
5
+ from .test_suite_run_prompt_sandbox_exec_config_data_request import TestSuiteRunPromptSandboxExecConfigDataRequest
6
+ import pydantic
7
+ from ..core.pydantic_utilities import IS_PYDANTIC_V2
8
+
9
+
10
+ class TestSuiteRunPromptSandboxExecConfigRequest(UniversalBaseModel):
11
+ """
12
+ Execution configuration for running a Test Suite against a Prompt Sandbox
13
+ """
14
+
15
+ type: typing.Literal["PROMPT_SANDBOX"] = "PROMPT_SANDBOX"
16
+ data: TestSuiteRunPromptSandboxExecConfigDataRequest
17
+ test_case_ids: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
18
+ """
19
+ Optionally specify a subset of test case ids to run. If not provided, all test cases within the test suite will be run by default.
20
+ """
21
+
22
+ if IS_PYDANTIC_V2:
23
+ model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
24
+ else:
25
+
26
+ class Config:
27
+ frozen = True
28
+ smart_union = True
29
+ extra = pydantic.Extra.allow
@@ -9,6 +9,7 @@ from .test_suite_run_state import TestSuiteRunState
9
9
  import pydantic
10
10
  import typing
11
11
  from .test_suite_run_exec_config import TestSuiteRunExecConfig
12
+ from .test_suite_run_progress import TestSuiteRunProgress
12
13
  from ..core.pydantic_utilities import IS_PYDANTIC_V2
13
14
  from ..core.pydantic_utilities import update_forward_refs
14
15
 
@@ -33,6 +34,8 @@ class TestSuiteRunRead(UniversalBaseModel):
33
34
  Configuration that defines how the Test Suite should be run
34
35
  """
35
36
 
37
+ progress: typing.Optional[TestSuiteRunProgress] = None
38
+
36
39
  if IS_PYDANTIC_V2:
37
40
  model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
38
41
  else:
@@ -0,0 +1,22 @@
1
+ # This file was auto-generated by Fern from our API Definition.
2
+
3
+ from ..core.pydantic_utilities import UniversalBaseModel
4
+ import pydantic
5
+ from ..core.pydantic_utilities import IS_PYDANTIC_V2
6
+ import typing
7
+
8
+
9
+ class TestSuiteRunWorkflowSandboxExecConfigDataRequest(UniversalBaseModel):
10
+ workflow_sandbox_id: str = pydantic.Field()
11
+ """
12
+ The ID of the Workflow Sandbox to run the Test Suite against.
13
+ """
14
+
15
+ if IS_PYDANTIC_V2:
16
+ model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
17
+ else:
18
+
19
+ class Config:
20
+ frozen = True
21
+ smart_union = True
22
+ extra = pydantic.Extra.allow
@@ -0,0 +1,29 @@
1
+ # This file was auto-generated by Fern from our API Definition.
2
+
3
+ from ..core.pydantic_utilities import UniversalBaseModel
4
+ import typing
5
+ from .test_suite_run_workflow_sandbox_exec_config_data_request import TestSuiteRunWorkflowSandboxExecConfigDataRequest
6
+ import pydantic
7
+ from ..core.pydantic_utilities import IS_PYDANTIC_V2
8
+
9
+
10
+ class TestSuiteRunWorkflowSandboxExecConfigRequest(UniversalBaseModel):
11
+ """
12
+ Execution configuration for running a Test Suite against a Workflow Sandbox
13
+ """
14
+
15
+ type: typing.Literal["WORKFLOW_SANDBOX"] = "WORKFLOW_SANDBOX"
16
+ data: TestSuiteRunWorkflowSandboxExecConfigDataRequest
17
+ test_case_ids: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
18
+ """
19
+ Optionally specify a subset of test case ids to run. If not provided, all test cases within the test suite will be run by default.
20
+ """
21
+
22
+ if IS_PYDANTIC_V2:
23
+ model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
24
+ else:
25
+
26
+ class Config:
27
+ frozen = True
28
+ smart_union = True
29
+ extra = pydantic.Extra.allow
@@ -11,6 +11,7 @@ VellumSdkErrorCodeEnum = typing.Union[
11
11
  "INVALID_CODE",
12
12
  "INVALID_TEMPLATE",
13
13
  "INTERNAL_ERROR",
14
+ "PROVIDER_CREDENTIALS_UNAVAILABLE",
14
15
  "PROVIDER_ERROR",
15
16
  "USER_DEFINED_ERROR",
16
17
  "WORKFLOW_CANCELLED",
@@ -6,6 +6,7 @@ WorkflowExecutionEventErrorCode = typing.Union[
6
6
  typing.Literal[
7
7
  "WORKFLOW_INITIALIZATION",
8
8
  "WORKFLOW_CANCELLED",
9
+ "PROVIDER_CREDENTIALS_UNAVAILABLE",
9
10
  "NODE_EXECUTION_COUNT_LIMIT_REACHED",
10
11
  "INTERNAL_SERVER_ERROR",
11
12
  "NODE_EXECUTION",
@@ -53,7 +53,7 @@ class OnValidatePython(ValidatePythonHandlerProtocol):
53
53
  return
54
54
 
55
55
  if self_instance:
56
- model_fields: Dict[str, FieldInfo] = self_instance.model_fields
56
+ model_fields: Dict[str, FieldInfo] = self_instance.__class__.model_fields
57
57
  else:
58
58
  model_fields = {}
59
59
 
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.test_suite_run_progress import *
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.test_suite_run_prompt_sandbox_exec_config_data_request import *
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.test_suite_run_prompt_sandbox_exec_config_request import *
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.test_suite_run_workflow_sandbox_exec_config_data_request import *
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.test_suite_run_workflow_sandbox_exec_config_request import *
@@ -17,6 +17,7 @@ class WorkflowErrorCode(Enum):
17
17
  INVALID_TEMPLATE = "INVALID_TEMPLATE"
18
18
  INTERNAL_ERROR = "INTERNAL_ERROR"
19
19
  NODE_EXECUTION = "NODE_EXECUTION"
20
+ PROVIDER_CREDENTIALS_UNAVAILABLE = "PROVIDER_CREDENTIALS_UNAVAILABLE"
20
21
  PROVIDER_ERROR = "PROVIDER_ERROR"
21
22
  USER_DEFINED_ERROR = "USER_DEFINED_ERROR"
22
23
  WORKFLOW_CANCELLED = "WORKFLOW_CANCELLED"
@@ -8,9 +8,10 @@ from vellum.workflows.expressions.accessor import AccessorExpression
8
8
  from vellum.workflows.outputs.base import BaseOutput
9
9
  from vellum.workflows.ports.port import Port
10
10
  from vellum.workflows.references.node import NodeReference
11
+ from vellum.workflows.types.definition import serialize_type_encoder_with_id
11
12
  from vellum.workflows.types.generics import OutputsType
12
13
 
13
- from .types import BaseEvent, default_serializer, serialize_type_encoder_with_id
14
+ from .types import BaseEvent, default_serializer
14
15
 
15
16
  if TYPE_CHECKING:
16
17
  from vellum.workflows.nodes.bases import BaseNode
@@ -89,6 +89,7 @@ mock_node_uuid = str(uuid4_from_hash(MockNode.__qualname__))
89
89
  "foo": "bar",
90
90
  },
91
91
  "display_context": None,
92
+ "initial_state": None,
92
93
  },
93
94
  "parent": None,
94
95
  },
@@ -1,13 +1,14 @@
1
1
  from datetime import datetime
2
2
  import json
3
3
  from uuid import UUID, uuid4
4
- from typing import Annotated, Any, Dict, List, Literal, Optional, Union, get_args
4
+ from typing import Annotated, Any, Literal, Optional, Union, get_args
5
5
 
6
- from pydantic import BeforeValidator, Field, GetCoreSchemaHandler, Tag, ValidationInfo
6
+ from pydantic import Field, GetCoreSchemaHandler, Tag, ValidationInfo
7
7
  from pydantic_core import CoreSchema, core_schema
8
8
 
9
9
  from vellum.core.pydantic_utilities import UniversalBaseModel
10
10
  from vellum.workflows.state.encoder import DefaultStateEncoder
11
+ from vellum.workflows.types.definition import VellumCodeResourceDefinition
11
12
  from vellum.workflows.types.utils import datetime_now
12
13
 
13
14
 
@@ -19,28 +20,6 @@ def default_datetime_factory() -> datetime:
19
20
  return datetime_now()
20
21
 
21
22
 
22
- excluded_modules = {"typing", "builtins"}
23
-
24
-
25
- def serialize_type_encoder(obj: type) -> Dict[str, Any]:
26
- return {
27
- "name": obj.__name__,
28
- "module": obj.__module__.split("."),
29
- }
30
-
31
-
32
- def serialize_type_encoder_with_id(obj: Union[type, "CodeResourceDefinition"]) -> Dict[str, Any]:
33
- if hasattr(obj, "__id__") and isinstance(obj, type):
34
- return {
35
- "id": getattr(obj, "__id__"),
36
- **serialize_type_encoder(obj),
37
- }
38
- elif isinstance(obj, CodeResourceDefinition):
39
- return obj.model_dump(mode="json")
40
-
41
- raise AttributeError(f"The object of type '{type(obj).__name__}' must have an '__id__' attribute.")
42
-
43
-
44
23
  def default_serializer(obj: Any) -> Any:
45
24
  return json.loads(
46
25
  json.dumps(
@@ -50,22 +29,6 @@ def default_serializer(obj: Any) -> Any:
50
29
  )
51
30
 
52
31
 
53
- class CodeResourceDefinition(UniversalBaseModel):
54
- id: UUID
55
- name: str
56
- module: List[str]
57
-
58
- @staticmethod
59
- def encode(obj: type) -> "CodeResourceDefinition":
60
- return CodeResourceDefinition(**serialize_type_encoder_with_id(obj))
61
-
62
-
63
- VellumCodeResourceDefinition = Annotated[
64
- CodeResourceDefinition,
65
- BeforeValidator(lambda d: (d if type(d) is dict else serialize_type_encoder_with_id(d))),
66
- ]
67
-
68
-
69
32
  class BaseParentContext(UniversalBaseModel):
70
33
  span_id: UUID
71
34
  parent: Optional["ParentContext"] = None
@@ -8,6 +8,7 @@ from vellum.core.pydantic_utilities import UniversalBaseModel
8
8
  from vellum.workflows.errors import WorkflowError
9
9
  from vellum.workflows.outputs.base import BaseOutput
10
10
  from vellum.workflows.references import ExternalInputReference
11
+ from vellum.workflows.types.definition import serialize_type_encoder_with_id
11
12
  from vellum.workflows.types.generics import InputsType, OutputsType, StateType
12
13
 
13
14
  from .node import (
@@ -18,7 +19,7 @@ from .node import (
18
19
  NodeExecutionResumedEvent,
19
20
  NodeExecutionStreamingEvent,
20
21
  )
21
- from .types import BaseEvent, default_serializer, serialize_type_encoder_with_id
22
+ from .types import BaseEvent, default_serializer
22
23
 
23
24
  if TYPE_CHECKING:
24
25
  from vellum.workflows.workflows.base import BaseWorkflow
@@ -53,8 +54,10 @@ class WorkflowEventDisplayContext(UniversalBaseModel):
53
54
  workflow_outputs: Dict[str, UUID]
54
55
 
55
56
 
56
- class WorkflowExecutionInitiatedBody(_BaseWorkflowExecutionBody, Generic[InputsType]):
57
+ class WorkflowExecutionInitiatedBody(_BaseWorkflowExecutionBody, Generic[InputsType, StateType]):
57
58
  inputs: InputsType
59
+ initial_state: Optional[StateType] = None
60
+
58
61
  # It is still the responsibility of the workflow server to populate this context. The SDK's
59
62
  # Workflow Runner will always leave this field None.
60
63
  #
@@ -66,15 +69,23 @@ class WorkflowExecutionInitiatedBody(_BaseWorkflowExecutionBody, Generic[InputsT
66
69
  def serialize_inputs(self, inputs: InputsType, _info: Any) -> Dict[str, Any]:
67
70
  return default_serializer(inputs)
68
71
 
72
+ @field_serializer("initial_state")
73
+ def serialize_initial_state(self, initial_state: Optional[StateType], _info: Any) -> Optional[Dict[str, Any]]:
74
+ return default_serializer(initial_state)
75
+
69
76
 
70
- class WorkflowExecutionInitiatedEvent(_BaseWorkflowEvent, Generic[InputsType]):
77
+ class WorkflowExecutionInitiatedEvent(_BaseWorkflowEvent, Generic[InputsType, StateType]):
71
78
  name: Literal["workflow.execution.initiated"] = "workflow.execution.initiated"
72
- body: WorkflowExecutionInitiatedBody[InputsType]
79
+ body: WorkflowExecutionInitiatedBody[InputsType, StateType]
73
80
 
74
81
  @property
75
82
  def inputs(self) -> InputsType:
76
83
  return self.body.inputs
77
84
 
85
+ @property
86
+ def initial_state(self) -> Optional[StateType]:
87
+ return self.body.initial_state
88
+
78
89
 
79
90
  class WorkflowExecutionStreamingBody(_BaseWorkflowExecutionBody):
80
91
  output: BaseOutput
@@ -69,7 +69,13 @@ class BasePromptNode(BaseNode, Generic[StateType]):
69
69
  return outputs
70
70
 
71
71
  def _handle_api_error(self, e: ApiError):
72
- if e.status_code and e.status_code >= 400 and e.status_code < 500 and isinstance(e.body, dict):
72
+ if e.status_code and e.status_code == 403 and isinstance(e.body, dict):
73
+ raise NodeException(
74
+ message=e.body.get("detail", "Provider credentials is missing or unavailable"),
75
+ code=WorkflowErrorCode.PROVIDER_CREDENTIALS_UNAVAILABLE,
76
+ )
77
+
78
+ elif e.status_code and e.status_code >= 400 and e.status_code < 500 and isinstance(e.body, dict):
73
79
  raise NodeException(
74
80
  message=e.body.get("detail", "Failed to execute Prompt"),
75
81
  code=WorkflowErrorCode.INVALID_INPUTS,
@@ -1,6 +1,6 @@
1
1
  import json
2
2
  from uuid import UUID
3
- from typing import Any, ClassVar, Dict, Generic, Iterator, List, Optional, Sequence, Union
3
+ from typing import Any, ClassVar, Dict, Generator, Generic, Iterator, List, Optional, Sequence, Set, Union
4
4
 
5
5
  from vellum import (
6
6
  ChatHistoryInputRequest,
@@ -9,17 +9,20 @@ from vellum import (
9
9
  JsonInputRequest,
10
10
  PromptDeploymentExpandMetaRequest,
11
11
  PromptDeploymentInputRequest,
12
+ PromptOutput,
12
13
  RawPromptExecutionOverridesRequest,
13
14
  StringInputRequest,
14
15
  )
15
- from vellum.client import RequestOptions
16
+ from vellum.client import ApiError, RequestOptions
16
17
  from vellum.client.types.chat_message_request import ChatMessageRequest
17
18
  from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT
18
19
  from vellum.workflows.context import get_execution_context
19
20
  from vellum.workflows.errors import WorkflowErrorCode
21
+ from vellum.workflows.errors.types import vellum_error_to_workflow_error
20
22
  from vellum.workflows.events.types import default_serializer
21
23
  from vellum.workflows.exceptions import NodeException
22
24
  from vellum.workflows.nodes.displayable.bases.base_prompt_node import BasePromptNode
25
+ from vellum.workflows.outputs import BaseOutput
23
26
  from vellum.workflows.types import MergeBehavior
24
27
  from vellum.workflows.types.generics import StateType
25
28
 
@@ -56,13 +59,21 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
56
59
  class Trigger(BasePromptNode.Trigger):
57
60
  merge_behavior = MergeBehavior.AWAIT_ANY
58
61
 
59
- def _get_prompt_event_stream(self) -> Iterator[ExecutePromptEvent]:
62
+ def _get_prompt_event_stream(self, ml_model_fallback: Optional[str] = None) -> Iterator[ExecutePromptEvent]:
60
63
  execution_context = get_execution_context()
61
64
  request_options = self.request_options or RequestOptions()
62
65
  request_options["additional_body_parameters"] = {
63
66
  "execution_context": execution_context.model_dump(mode="json"),
64
67
  **request_options.get("additional_body_parameters", {}),
65
68
  }
69
+ if ml_model_fallback:
70
+ request_options["additional_body_parameters"] = {
71
+ "overrides": {
72
+ "ml_model_fallback": ml_model_fallback,
73
+ },
74
+ **request_options.get("additional_body_parameters", {}),
75
+ }
76
+
66
77
  return self._context.vellum_client.execute_prompt_stream(
67
78
  inputs=self._compile_prompt_inputs(),
68
79
  prompt_deployment_id=str(self.deployment) if isinstance(self.deployment, UUID) else None,
@@ -76,6 +87,86 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
76
87
  request_options=request_options,
77
88
  )
78
89
 
90
+ def _process_prompt_event_stream(
91
+ self,
92
+ prompt_event_stream: Optional[Iterator[ExecutePromptEvent]] = None,
93
+ tried_fallbacks: Optional[set[str]] = None,
94
+ ) -> Generator[BaseOutput, None, Optional[List[PromptOutput]]]:
95
+ """Override the base prompt node _process_prompt_event_stream()"""
96
+ self._validate()
97
+
98
+ if tried_fallbacks is None:
99
+ tried_fallbacks = set()
100
+
101
+ if prompt_event_stream is None:
102
+ try:
103
+ prompt_event_stream = self._get_prompt_event_stream()
104
+ next(prompt_event_stream)
105
+ except ApiError as e:
106
+ if (
107
+ e.status_code
108
+ and e.status_code < 500
109
+ and self.ml_model_fallbacks is not OMIT
110
+ and self.ml_model_fallbacks is not None
111
+ ):
112
+ prompt_event_stream = self._retry_prompt_stream_with_fallbacks(tried_fallbacks)
113
+ else:
114
+ self._handle_api_error(e)
115
+
116
+ outputs: Optional[List[PromptOutput]] = None
117
+ if prompt_event_stream is not None:
118
+ for event in prompt_event_stream:
119
+ if event.state == "INITIATED":
120
+ continue
121
+ elif event.state == "STREAMING":
122
+ yield BaseOutput(name="results", delta=event.output.value)
123
+ elif event.state == "FULFILLED":
124
+ outputs = event.outputs
125
+ yield BaseOutput(name="results", value=event.outputs)
126
+ elif event.state == "REJECTED":
127
+ if (
128
+ event.error
129
+ and event.error.code == WorkflowErrorCode.PROVIDER_ERROR.value
130
+ and self.ml_model_fallbacks is not OMIT
131
+ and self.ml_model_fallbacks is not None
132
+ ):
133
+ try:
134
+ fallback_stream = self._retry_prompt_stream_with_fallbacks(tried_fallbacks)
135
+ fallback_outputs = yield from self._process_prompt_event_stream(
136
+ fallback_stream, tried_fallbacks
137
+ )
138
+ return fallback_outputs
139
+ except ApiError:
140
+ pass
141
+
142
+ workflow_error = vellum_error_to_workflow_error(event.error)
143
+ raise NodeException.of(workflow_error)
144
+
145
+ return outputs
146
+
147
+ def _retry_prompt_stream_with_fallbacks(self, tried_fallbacks: Set[str]) -> Optional[Iterator[ExecutePromptEvent]]:
148
+ if self.ml_model_fallbacks is not None:
149
+ for ml_model_fallback in self.ml_model_fallbacks:
150
+ if ml_model_fallback in tried_fallbacks:
151
+ continue
152
+
153
+ try:
154
+ tried_fallbacks.add(ml_model_fallback)
155
+ prompt_event_stream = self._get_prompt_event_stream(ml_model_fallback=ml_model_fallback)
156
+ next(prompt_event_stream)
157
+ return prompt_event_stream
158
+ except ApiError:
159
+ continue
160
+ else:
161
+ self._handle_api_error(
162
+ ApiError(
163
+ body={"detail": f"Failed to execute prompts with these fallbacks: {self.ml_model_fallbacks}"},
164
+ status_code=400,
165
+ )
166
+ )
167
+
168
+ return None
169
+
79
170
  def _compile_prompt_inputs(self) -> List[PromptDeploymentInputRequest]:
80
171
  # TODO: We may want to consolidate with subworkflow deployment input compilation
81
172
  # https://app.shortcut.com/vellum/story/4117
@@ -1,12 +1,8 @@
1
1
  import pytest
2
2
  from uuid import UUID
3
3
 
4
- from vellum.workflows.events.types import (
5
- CodeResourceDefinition,
6
- NodeParentContext,
7
- WorkflowDeploymentParentContext,
8
- WorkflowParentContext,
9
- )
4
+ from vellum.workflows.events.types import NodeParentContext, WorkflowDeploymentParentContext, WorkflowParentContext
5
+ from vellum.workflows.types.definition import CodeResourceDefinition
10
6
 
11
7
 
12
8
  @pytest.fixture
@@ -117,7 +117,7 @@ class GuardrailNode(BaseNode[StateType], Generic[StateType]):
117
117
  value=cast(Dict[str, Any], input_value),
118
118
  )
119
119
  )
120
- elif isinstance(input_value, float):
120
+ elif isinstance(input_value, (int, float)):
121
121
  compiled_inputs.append(
122
122
  NumberInput(
123
123
  name=input_name,