vellum-ai 0.14.35__py3-none-any.whl → 0.14.37__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (42)
  1. vellum/__init__.py +6 -0
  2. vellum/client/__init__.py +4 -4
  3. vellum/client/core/client_wrapper.py +1 -1
  4. vellum/client/resources/release_reviews/client.py +118 -1
  5. vellum/client/types/__init__.py +6 -0
  6. vellum/client/types/logical_operator.py +1 -0
  7. vellum/client/types/prompt_deployment_release.py +34 -0
  8. vellum/client/types/prompt_deployment_release_prompt_deployment.py +19 -0
  9. vellum/client/types/prompt_deployment_release_prompt_version.py +19 -0
  10. vellum/types/prompt_deployment_release.py +3 -0
  11. vellum/types/prompt_deployment_release_prompt_deployment.py +3 -0
  12. vellum/types/prompt_deployment_release_prompt_version.py +3 -0
  13. vellum/workflows/inputs/base.py +2 -1
  14. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +2 -0
  15. vellum/workflows/nodes/displayable/guardrail_node/node.py +35 -12
  16. vellum/workflows/nodes/displayable/guardrail_node/test_node.py +88 -0
  17. vellum/workflows/nodes/displayable/prompt_deployment_node/node.py +14 -2
  18. vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +43 -0
  19. vellum/workflows/state/base.py +38 -3
  20. vellum/workflows/state/tests/test_state.py +49 -0
  21. vellum/workflows/workflows/base.py +17 -0
  22. vellum/workflows/workflows/tests/test_base_workflow.py +39 -0
  23. {vellum_ai-0.14.35.dist-info → vellum_ai-0.14.37.dist-info}/METADATA +1 -1
  24. {vellum_ai-0.14.35.dist-info → vellum_ai-0.14.37.dist-info}/RECORD +42 -35
  25. vellum_cli/pull.py +3 -0
  26. vellum_cli/tests/test_pull.py +3 -1
  27. vellum_ee/workflows/display/base.py +9 -7
  28. vellum_ee/workflows/display/nodes/__init__.py +2 -2
  29. vellum_ee/workflows/display/nodes/vellum/note_node.py +1 -2
  30. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +2 -0
  31. vellum_ee/workflows/display/nodes/vellum/tests/test_note_node.py +33 -0
  32. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +3 -4
  33. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +1 -1
  34. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +0 -1
  35. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -0
  36. vellum_ee/workflows/display/types.py +6 -7
  37. vellum_ee/workflows/display/vellum.py +5 -4
  38. vellum_ee/workflows/display/workflows/base_workflow_display.py +20 -19
  39. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +11 -37
  40. {vellum_ai-0.14.35.dist-info → vellum_ai-0.14.37.dist-info}/LICENSE +0 -0
  41. {vellum_ai-0.14.35.dist-info → vellum_ai-0.14.37.dist-info}/WHEEL +0 -0
  42. {vellum_ai-0.14.35.dist-info → vellum_ai-0.14.37.dist-info}/entry_points.txt +0 -0
vellum/__init__.py CHANGED
@@ -300,6 +300,9 @@ from .types import (
      PromptDeploymentExpandMetaRequest,
      PromptDeploymentInputRequest,
      PromptDeploymentParentContext,
+     PromptDeploymentRelease,
+     PromptDeploymentReleasePromptDeployment,
+     PromptDeploymentReleasePromptVersion,
      PromptExecConfig,
      PromptExecutionMeta,
      PromptNodeExecutionMeta,
@@ -922,6 +925,9 @@ __all__ = [
      "PromptDeploymentExpandMetaRequest",
      "PromptDeploymentInputRequest",
      "PromptDeploymentParentContext",
+     "PromptDeploymentRelease",
+     "PromptDeploymentReleasePromptDeployment",
+     "PromptDeploymentReleasePromptVersion",
      "PromptExecConfig",
      "PromptExecutionMeta",
      "PromptNodeExecutionMeta",
vellum/client/__init__.py CHANGED
@@ -7,6 +7,7 @@ from .core.client_wrapper import SyncClientWrapper
  from .resources.ad_hoc.client import AdHocClient
  from .resources.container_images.client import ContainerImagesClient
  from .resources.deployments.client import DeploymentsClient
+ from .resources.release_reviews.client import ReleaseReviewsClient
  from .resources.document_indexes.client import DocumentIndexesClient
  from .resources.documents.client import DocumentsClient
  from .resources.folder_entities.client import FolderEntitiesClient
@@ -18,7 +19,6 @@ from .resources.sandboxes.client import SandboxesClient
  from .resources.test_suite_runs.client import TestSuiteRunsClient
  from .resources.test_suites.client import TestSuitesClient
  from .resources.workflow_deployments.client import WorkflowDeploymentsClient
- from .resources.release_reviews.client import ReleaseReviewsClient
  from .resources.workflow_sandboxes.client import WorkflowSandboxesClient
  from .resources.workflows.client import WorkflowsClient
  from .resources.workspace_secrets.client import WorkspaceSecretsClient
@@ -65,6 +65,7 @@ from .core.client_wrapper import AsyncClientWrapper
  from .resources.ad_hoc.client import AsyncAdHocClient
  from .resources.container_images.client import AsyncContainerImagesClient
  from .resources.deployments.client import AsyncDeploymentsClient
+ from .resources.release_reviews.client import AsyncReleaseReviewsClient
  from .resources.document_indexes.client import AsyncDocumentIndexesClient
  from .resources.documents.client import AsyncDocumentsClient
  from .resources.folder_entities.client import AsyncFolderEntitiesClient
@@ -76,7 +77,6 @@ from .resources.sandboxes.client import AsyncSandboxesClient
  from .resources.test_suite_runs.client import AsyncTestSuiteRunsClient
  from .resources.test_suites.client import AsyncTestSuitesClient
  from .resources.workflow_deployments.client import AsyncWorkflowDeploymentsClient
- from .resources.release_reviews.client import AsyncReleaseReviewsClient
  from .resources.workflow_sandboxes.client import AsyncWorkflowSandboxesClient
  from .resources.workflows.client import AsyncWorkflowsClient
  from .resources.workspace_secrets.client import AsyncWorkspaceSecretsClient
@@ -143,6 +143,7 @@ class Vellum:
          self.ad_hoc = AdHocClient(client_wrapper=self._client_wrapper)
          self.container_images = ContainerImagesClient(client_wrapper=self._client_wrapper)
          self.deployments = DeploymentsClient(client_wrapper=self._client_wrapper)
+         self.release_reviews = ReleaseReviewsClient(client_wrapper=self._client_wrapper)
          self.document_indexes = DocumentIndexesClient(client_wrapper=self._client_wrapper)
          self.documents = DocumentsClient(client_wrapper=self._client_wrapper)
          self.folder_entities = FolderEntitiesClient(client_wrapper=self._client_wrapper)
@@ -154,7 +155,6 @@ class Vellum:
          self.test_suite_runs = TestSuiteRunsClient(client_wrapper=self._client_wrapper)
          self.test_suites = TestSuitesClient(client_wrapper=self._client_wrapper)
          self.workflow_deployments = WorkflowDeploymentsClient(client_wrapper=self._client_wrapper)
-         self.release_reviews = ReleaseReviewsClient(client_wrapper=self._client_wrapper)
          self.workflow_sandboxes = WorkflowSandboxesClient(client_wrapper=self._client_wrapper)
          self.workflows = WorkflowsClient(client_wrapper=self._client_wrapper)
          self.workspace_secrets = WorkspaceSecretsClient(client_wrapper=self._client_wrapper)
@@ -1486,6 +1486,7 @@ class AsyncVellum:
          self.ad_hoc = AsyncAdHocClient(client_wrapper=self._client_wrapper)
          self.container_images = AsyncContainerImagesClient(client_wrapper=self._client_wrapper)
          self.deployments = AsyncDeploymentsClient(client_wrapper=self._client_wrapper)
+         self.release_reviews = AsyncReleaseReviewsClient(client_wrapper=self._client_wrapper)
          self.document_indexes = AsyncDocumentIndexesClient(client_wrapper=self._client_wrapper)
          self.documents = AsyncDocumentsClient(client_wrapper=self._client_wrapper)
          self.folder_entities = AsyncFolderEntitiesClient(client_wrapper=self._client_wrapper)
@@ -1497,7 +1498,6 @@ class AsyncVellum:
          self.test_suite_runs = AsyncTestSuiteRunsClient(client_wrapper=self._client_wrapper)
          self.test_suites = AsyncTestSuitesClient(client_wrapper=self._client_wrapper)
          self.workflow_deployments = AsyncWorkflowDeploymentsClient(client_wrapper=self._client_wrapper)
-         self.release_reviews = AsyncReleaseReviewsClient(client_wrapper=self._client_wrapper)
          self.workflow_sandboxes = AsyncWorkflowSandboxesClient(client_wrapper=self._client_wrapper)
          self.workflows = AsyncWorkflowsClient(client_wrapper=self._client_wrapper)
          self.workspace_secrets = AsyncWorkspaceSecretsClient(client_wrapper=self._client_wrapper)
vellum/client/core/client_wrapper.py CHANGED
@@ -18,7 +18,7 @@ class BaseClientWrapper:
          headers: typing.Dict[str, str] = {
              "X-Fern-Language": "Python",
              "X-Fern-SDK-Name": "vellum-ai",
-             "X-Fern-SDK-Version": "0.14.35",
+             "X-Fern-SDK-Version": "0.14.37",
          }
          headers["X_API_KEY"] = self.api_key
          return headers
vellum/client/resources/release_reviews/client.py CHANGED
@@ -3,11 +3,12 @@
  from ...core.client_wrapper import SyncClientWrapper
  import typing
  from ...core.request_options import RequestOptions
- from ...types.workflow_deployment_release import WorkflowDeploymentRelease
+ from ...types.prompt_deployment_release import PromptDeploymentRelease
  from ...core.jsonable_encoder import jsonable_encoder
  from ...core.pydantic_utilities import parse_obj_as
  from json.decoder import JSONDecodeError
  from ...core.api_error import ApiError
+ from ...types.workflow_deployment_release import WorkflowDeploymentRelease
  from ...core.client_wrapper import AsyncClientWrapper


@@ -15,6 +16,60 @@ class ReleaseReviewsClient:
      def __init__(self, *, client_wrapper: SyncClientWrapper):
          self._client_wrapper = client_wrapper

+     def retrieve_prompt_deployment_release(
+         self, id: str, release_id_or_release_tag: str, *, request_options: typing.Optional[RequestOptions] = None
+     ) -> PromptDeploymentRelease:
+         """
+         Retrieve a specific Prompt Deployment Release by either its UUID or the name of a Release Tag that points to it.
+
+         Parameters
+         ----------
+         id : str
+             A UUID string identifying this deployment.
+
+         release_id_or_release_tag : str
+             Either the UUID of Prompt Deployment Release you'd like to retrieve, or the name of a Release Tag that's pointing to the Prompt Deployment Release you'd like to retrieve.
+
+         request_options : typing.Optional[RequestOptions]
+             Request-specific configuration.
+
+         Returns
+         -------
+         PromptDeploymentRelease
+
+
+         Examples
+         --------
+         from vellum import Vellum
+
+         client = Vellum(
+             api_key="YOUR_API_KEY",
+         )
+         client.release_reviews.retrieve_prompt_deployment_release(
+             id="id",
+             release_id_or_release_tag="release_id_or_release_tag",
+         )
+         """
+         _response = self._client_wrapper.httpx_client.request(
+             f"v1/deployments/{jsonable_encoder(id)}/releases/{jsonable_encoder(release_id_or_release_tag)}",
+             base_url=self._client_wrapper.get_environment().default,
+             method="GET",
+             request_options=request_options,
+         )
+         try:
+             if 200 <= _response.status_code < 300:
+                 return typing.cast(
+                     PromptDeploymentRelease,
+                     parse_obj_as(
+                         type_=PromptDeploymentRelease,  # type: ignore
+                         object_=_response.json(),
+                     ),
+                 )
+             _response_json = _response.json()
+         except JSONDecodeError:
+             raise ApiError(status_code=_response.status_code, body=_response.text)
+         raise ApiError(status_code=_response.status_code, body=_response_json)
+
      def retrieve_workflow_deployment_release(
          self, id: str, release_id_or_release_tag: str, *, request_options: typing.Optional[RequestOptions] = None
      ) -> WorkflowDeploymentRelease:
@@ -74,6 +129,68 @@ class AsyncReleaseReviewsClient:
      def __init__(self, *, client_wrapper: AsyncClientWrapper):
          self._client_wrapper = client_wrapper

+     async def retrieve_prompt_deployment_release(
+         self, id: str, release_id_or_release_tag: str, *, request_options: typing.Optional[RequestOptions] = None
+     ) -> PromptDeploymentRelease:
+         """
+         Retrieve a specific Prompt Deployment Release by either its UUID or the name of a Release Tag that points to it.
+
+         Parameters
+         ----------
+         id : str
+             A UUID string identifying this deployment.
+
+         release_id_or_release_tag : str
+             Either the UUID of Prompt Deployment Release you'd like to retrieve, or the name of a Release Tag that's pointing to the Prompt Deployment Release you'd like to retrieve.
+
+         request_options : typing.Optional[RequestOptions]
+             Request-specific configuration.
+
+         Returns
+         -------
+         PromptDeploymentRelease
+
+
+         Examples
+         --------
+         import asyncio
+
+         from vellum import AsyncVellum
+
+         client = AsyncVellum(
+             api_key="YOUR_API_KEY",
+         )
+
+
+         async def main() -> None:
+             await client.release_reviews.retrieve_prompt_deployment_release(
+                 id="id",
+                 release_id_or_release_tag="release_id_or_release_tag",
+             )
+
+
+         asyncio.run(main())
+         """
+         _response = await self._client_wrapper.httpx_client.request(
+             f"v1/deployments/{jsonable_encoder(id)}/releases/{jsonable_encoder(release_id_or_release_tag)}",
+             base_url=self._client_wrapper.get_environment().default,
+             method="GET",
+             request_options=request_options,
+         )
+         try:
+             if 200 <= _response.status_code < 300:
+                 return typing.cast(
+                     PromptDeploymentRelease,
+                     parse_obj_as(
+                         type_=PromptDeploymentRelease,  # type: ignore
+                         object_=_response.json(),
+                     ),
+                 )
+             _response_json = _response.json()
+         except JSONDecodeError:
+             raise ApiError(status_code=_response.status_code, body=_response.text)
+         raise ApiError(status_code=_response.status_code, body=_response_json)
+
      async def retrieve_workflow_deployment_release(
          self, id: str, release_id_or_release_tag: str, *, request_options: typing.Optional[RequestOptions] = None
      ) -> WorkflowDeploymentRelease:
vellum/client/types/__init__.py CHANGED
@@ -308,6 +308,9 @@ from .prompt_block_state import PromptBlockState
  from .prompt_deployment_expand_meta_request import PromptDeploymentExpandMetaRequest
  from .prompt_deployment_input_request import PromptDeploymentInputRequest
  from .prompt_deployment_parent_context import PromptDeploymentParentContext
+ from .prompt_deployment_release import PromptDeploymentRelease
+ from .prompt_deployment_release_prompt_deployment import PromptDeploymentReleasePromptDeployment
+ from .prompt_deployment_release_prompt_version import PromptDeploymentReleasePromptVersion
  from .prompt_exec_config import PromptExecConfig
  from .prompt_execution_meta import PromptExecutionMeta
  from .prompt_node_execution_meta import PromptNodeExecutionMeta
@@ -902,6 +905,9 @@ __all__ = [
      "PromptDeploymentExpandMetaRequest",
      "PromptDeploymentInputRequest",
      "PromptDeploymentParentContext",
+     "PromptDeploymentRelease",
+     "PromptDeploymentReleasePromptDeployment",
+     "PromptDeploymentReleasePromptVersion",
      "PromptExecConfig",
      "PromptExecutionMeta",
      "PromptNodeExecutionMeta",
vellum/client/types/logical_operator.py CHANGED
@@ -29,6 +29,7 @@ LogicalOperator = typing.Union[
          "parseJson",
          "and",
          "or",
+         "isError",
      ],
      typing.Any,
  ]
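A quick sketch of what the widened union means for callers; the import path and the LogicalOperator alias come from the file above, while the variable name is illustrative:

from vellum.client.types.logical_operator import LogicalOperator

# "isError" is now one of the accepted literal members of LogicalOperator as of this release.
op: LogicalOperator = "isError"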
vellum/client/types/prompt_deployment_release.py ADDED
@@ -0,0 +1,34 @@
+ # This file was auto-generated by Fern from our API Definition.
+
+ from ..core.pydantic_utilities import UniversalBaseModel
+ import datetime as dt
+ from .release_environment import ReleaseEnvironment
+ import typing
+ from .release_created_by import ReleaseCreatedBy
+ from .prompt_deployment_release_prompt_version import PromptDeploymentReleasePromptVersion
+ from .prompt_deployment_release_prompt_deployment import PromptDeploymentReleasePromptDeployment
+ from .release_release_tag import ReleaseReleaseTag
+ from .slim_release_review import SlimReleaseReview
+ from ..core.pydantic_utilities import IS_PYDANTIC_V2
+ import pydantic
+
+
+ class PromptDeploymentRelease(UniversalBaseModel):
+     id: str
+     created: dt.datetime
+     environment: ReleaseEnvironment
+     created_by: typing.Optional[ReleaseCreatedBy] = None
+     prompt_version: PromptDeploymentReleasePromptVersion
+     deployment: PromptDeploymentReleasePromptDeployment
+     description: typing.Optional[str] = None
+     release_tags: typing.List[ReleaseReleaseTag]
+     reviews: typing.List[SlimReleaseReview]
+
+     if IS_PYDANTIC_V2:
+         model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
+     else:
+
+         class Config:
+             frozen = True
+             smart_union = True
+             extra = pydantic.Extra.allow
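Putting the new client methods and the new PromptDeploymentRelease model together, a minimal usage sketch; the method, field, and constructor names come from this diff, while the API key, deployment ID, and release tag values are placeholders:

from vellum import Vellum

client = Vellum(api_key="YOUR_API_KEY")

# Placeholder identifiers; pass a real Prompt Deployment UUID and a real Release Tag name.
release = client.release_reviews.retrieve_prompt_deployment_release(
    id="YOUR_DEPLOYMENT_ID",
    release_id_or_release_tag="production",
)

# Fields defined on PromptDeploymentRelease above:
print(release.id, release.created, release.environment)
print(release.deployment.name, release.prompt_version.id)
print(len(release.release_tags), len(release.reviews))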
vellum/client/types/prompt_deployment_release_prompt_deployment.py ADDED
@@ -0,0 +1,19 @@
+ # This file was auto-generated by Fern from our API Definition.
+
+ from ..core.pydantic_utilities import UniversalBaseModel
+ from ..core.pydantic_utilities import IS_PYDANTIC_V2
+ import typing
+ import pydantic
+
+
+ class PromptDeploymentReleasePromptDeployment(UniversalBaseModel):
+     name: str
+
+     if IS_PYDANTIC_V2:
+         model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
+     else:
+
+         class Config:
+             frozen = True
+             smart_union = True
+             extra = pydantic.Extra.allow
vellum/client/types/prompt_deployment_release_prompt_version.py ADDED
@@ -0,0 +1,19 @@
+ # This file was auto-generated by Fern from our API Definition.
+
+ from ..core.pydantic_utilities import UniversalBaseModel
+ from ..core.pydantic_utilities import IS_PYDANTIC_V2
+ import typing
+ import pydantic
+
+
+ class PromptDeploymentReleasePromptVersion(UniversalBaseModel):
+     id: str
+
+     if IS_PYDANTIC_V2:
+         model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
+     else:
+
+         class Config:
+             frozen = True
+             smart_union = True
+             extra = pydantic.Extra.allow
vellum/types/prompt_deployment_release.py ADDED
@@ -0,0 +1,3 @@
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
+
+ from vellum.client.types.prompt_deployment_release import *
vellum/types/prompt_deployment_release_prompt_deployment.py ADDED
@@ -0,0 +1,3 @@
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
+
+ from vellum.client.types.prompt_deployment_release_prompt_deployment import *
vellum/types/prompt_deployment_release_prompt_version.py ADDED
@@ -0,0 +1,3 @@
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
+
+ from vellum.client.types.prompt_deployment_release_prompt_version import *
vellum/workflows/inputs/base.py CHANGED
@@ -4,6 +4,7 @@ from typing_extensions import dataclass_transform
  from pydantic import GetCoreSchemaHandler
  from pydantic_core import core_schema

+ from vellum.workflows.constants import undefined
  from vellum.workflows.errors.types import WorkflowErrorCode
  from vellum.workflows.exceptions import WorkflowInitializationException
  from vellum.workflows.references import ExternalInputReference, WorkflowInputReference
@@ -15,7 +16,7 @@ from vellum.workflows.types.utils import get_class_attr_names, infer_types
  class _BaseInputsMeta(type):
      def __getattribute__(cls, name: str) -> Any:
          if not name.startswith("_") and name in cls.__annotations__ and issubclass(cls, BaseInputs):
-             instance = vars(cls).get(name)
+             instance = vars(cls).get(name, undefined)
              types = infer_types(cls, name)

              if getattr(cls, "__descriptor_class__", None) is ExternalInputReference:
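For context on the one-line change above, a self-contained sketch of the sentinel-default pattern it relies on, using a local stand-in rather than vellum's actual undefined constant: passing a sentinel to vars(cls).get() lets the metaclass distinguish "no default declared" from "the default is None".

class _Undefined:
    """Local stand-in for a sentinel like vellum.workflows.constants.undefined."""

    def __repr__(self) -> str:
        return "undefined"


undefined = _Undefined()


class Inputs:
    city: str = "Paris"
    country: str  # annotated, but no default value declared


# .get(name) alone returns None for both a missing default and an explicit None default;
# .get(name, undefined) keeps the two cases distinguishable.
print(vars(Inputs).get("city", undefined))     # Paris
print(vars(Inputs).get("country", undefined))  # undefined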
vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py CHANGED
@@ -38,6 +38,7 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
      expand_raw: Optional[Sequence[str]] - Expandable raw fields to include in the response
      metadata: Optional[Dict[str, Optional[Any]]] - The metadata to use for the Prompt Execution
      request_options: Optional[RequestOptions] - The request options to use for the Prompt Execution
+     ml_model_fallback: Optional[Sequence[str]] - ML model fallbacks to use
      """

      # Either the Prompt Deployment's UUID or its name.
@@ -50,6 +51,7 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
      raw_overrides: Optional[RawPromptExecutionOverridesRequest] = OMIT
      expand_raw: Optional[Sequence[str]] = OMIT
      metadata: Optional[Dict[str, Optional[Any]]] = OMIT
+     ml_model_fallbacks: Optional[Sequence[str]] = OMIT

      class Trigger(BasePromptNode.Trigger):
          merge_behavior = MergeBehavior.AWAIT_ANY
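A hedged configuration sketch showing where the new attribute sits on a node subclass; the subclass name, deployment name, and model identifiers are illustrative, and the fallback semantics (alternate models to try when the primary model fails) are inferred from the docstring line added above:

from vellum.workflows.nodes.displayable.bases import BasePromptDeploymentNode


class MyPromptNode(BasePromptDeploymentNode):
    # `deployment` and `ml_model_fallbacks` are the class attributes shown in this diff;
    # the values here are placeholders, not a tested configuration.
    deployment = "my-prompt-deployment"
    ml_model_fallbacks = ["fallback-model-a", "fallback-model-b"]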
vellum/workflows/nodes/displayable/guardrail_node/node.py CHANGED
@@ -2,6 +2,7 @@ from uuid import UUID
  from typing import Any, ClassVar, Dict, Generic, List, Optional, Union, cast

  from vellum import ChatHistoryInput, ChatMessage, JsonInput, MetricDefinitionInput, NumberInput, StringInput
+ from vellum.client import ApiError
  from vellum.core import RequestOptions
  from vellum.workflows.constants import LATEST_RELEASE_TAG
  from vellum.workflows.errors.types import WorkflowErrorCode
@@ -34,28 +35,50 @@ class GuardrailNode(BaseNode[StateType], Generic[StateType]):

      class Outputs(BaseOutputs):
          score: float
+         normalized_score: Optional[float]
          log: Optional[str]

      def run(self) -> Outputs:
-         metric_execution = self._context.vellum_client.metric_definitions.execute_metric_definition(
-             self.metric_definition if isinstance(self.metric_definition, str) else str(self.metric_definition),
-             inputs=self._compile_metric_inputs(),
-             release_tag=self.release_tag,
-             request_options=self.request_options,
-         )
+         try:
+             metric_execution = self._context.vellum_client.metric_definitions.execute_metric_definition(
+                 self.metric_definition if isinstance(self.metric_definition, str) else str(self.metric_definition),
+                 inputs=self._compile_metric_inputs(),
+                 release_tag=self.release_tag,
+                 request_options=self.request_options,
+             )
+
+         except ApiError:
+             raise NodeException(
+                 code=WorkflowErrorCode.NODE_EXECUTION,
+                 message="Failed to execute metric definition",
+             )

          metric_outputs = {output.name: output.value for output in metric_execution.outputs}

-         score = metric_outputs.get("score")
+         SCORE_KEY = "score"
+         NORMALIZED_SCORE_KEY = "normalized_score"
+         LOG_KEY = "log"
+
+         score = metric_outputs.get(SCORE_KEY)
          if not isinstance(score, float):
              raise NodeException(
-                 message="Metric execution must have one output named 'score' with type 'float'",
+                 message=f"Metric execution must have one output named '{SCORE_KEY}' with type 'float'",
                  code=WorkflowErrorCode.INVALID_OUTPUTS,
              )
-         metric_outputs.pop("score")
+         metric_outputs.pop(SCORE_KEY)
+
+         if NORMALIZED_SCORE_KEY in metric_outputs:
+             normalized_score = metric_outputs.pop(NORMALIZED_SCORE_KEY)
+             if not isinstance(normalized_score, float):
+                 raise NodeException(
+                     message=f"Metric execution must have one output named '{NORMALIZED_SCORE_KEY}' with type 'float'",
+                     code=WorkflowErrorCode.INVALID_OUTPUTS,
+                 )
+         else:
+             normalized_score = None

-         if "log" in metric_outputs:
-             log = metric_outputs.pop("log") or ""
+         if LOG_KEY in metric_outputs:
+             log = metric_outputs.pop(LOG_KEY) or ""
              if not isinstance(log, str):
                  raise NodeException(
                      message="Metric execution log output must be of type 'str'",
@@ -64,7 +87,7 @@ class GuardrailNode(BaseNode[StateType], Generic[StateType]):
          else:
              log = None

-         return self.Outputs(score=score, log=log, **metric_outputs)
+         return self.Outputs(score=score, normalized_score=normalized_score, log=log, **metric_outputs)

      def _compile_metric_inputs(self) -> List[MetricDefinitionInput]:
          # TODO: We may want to consolidate with prompt deployment input compilation
vellum/workflows/nodes/displayable/guardrail_node/test_node.py CHANGED
@@ -1,8 +1,11 @@
  import pytest

  from vellum import TestSuiteRunMetricNumberOutput
+ from vellum.client import ApiError
  from vellum.client.types.metric_definition_execution import MetricDefinitionExecution
  from vellum.client.types.test_suite_run_metric_string_output import TestSuiteRunMetricStringOutput
+ from vellum.workflows.errors import WorkflowErrorCode
+ from vellum.workflows.exceptions import NodeException
  from vellum.workflows.nodes.displayable.guardrail_node.node import GuardrailNode


@@ -36,3 +39,88 @@ def test_run_guardrail_node__empty_log(vellum_client, log_value):
      # THEN the workflow should have completed successfully
      assert outputs.score == 0.6
      assert outputs.log == ""
+
+
+ def test_run_guardrail_node__normalized_score(vellum_client):
+     """Confirm that we can successfully invoke a Guardrail Node"""
+
+     # GIVEN a Guardrail Node
+     class MyGuard(GuardrailNode):
+         metric_definition = "example_metric_definition"
+         metric_inputs = {}
+
+     # AND we know that the guardrail node will return a normalized score
+     mock_metric_execution = MetricDefinitionExecution(
+         outputs=[
+             TestSuiteRunMetricNumberOutput(
+                 name="score",
+                 value=0.6,
+             ),
+             TestSuiteRunMetricNumberOutput(
+                 name="normalized_score",
+                 value=1.0,
+             ),
+         ],
+     )
+     vellum_client.metric_definitions.execute_metric_definition.return_value = mock_metric_execution
+
+     # WHEN we run the Guardrail Node
+     outputs = MyGuard().run()
+
+     # THEN the workflow should have completed successfully
+     assert outputs.score == 0.6
+     assert outputs.normalized_score == 1.0
+
+
+ def test_run_guardrail_node__normalized_score_null(vellum_client):
+     # GIVEN a Guardrail Node
+     class MyGuard(GuardrailNode):
+         metric_definition = "example_metric_definition"
+         metric_inputs = {}
+
+     # AND we know that the guardrail node will return a normalized score that is None
+     mock_metric_execution = MetricDefinitionExecution(
+         outputs=[
+             TestSuiteRunMetricNumberOutput(
+                 name="score",
+                 value=0.6,
+             ),
+             TestSuiteRunMetricNumberOutput(
+                 name="normalized_score",
+                 value=None,
+             ),
+         ],
+     )
+     vellum_client.metric_definitions.execute_metric_definition.return_value = mock_metric_execution
+
+     # WHEN we run the Guardrail Node
+     with pytest.raises(NodeException) as exc_info:
+         MyGuard().run()
+
+     # THEN we get an exception
+     assert exc_info.value.message == "Metric execution must have one output named 'normalized_score' with type 'float'"
+     assert exc_info.value.code == WorkflowErrorCode.INVALID_OUTPUTS
+
+
+ def test_run_guardrail_node__api_error(vellum_client):
+     # GIVEN a Guardrail Node
+     class MyGuard(GuardrailNode):
+         metric_definition = "example_metric_definition"
+         metric_inputs = {}
+
+     # AND the API client raises an ApiError when called
+     api_error = ApiError(status_code=503)
+     vellum_client.metric_definitions.execute_metric_definition.side_effect = api_error
+
+     # WHEN we run the Guardrail Node
+     with pytest.raises(NodeException) as exc_info:
+         MyGuard().run()
+
+     # THEN we get a NodeException with the appropriate error code
+     assert exc_info.value.code == WorkflowErrorCode.NODE_EXECUTION
+     assert "Failed to execute metric definition" in exc_info.value.message
+
+     # Verify the mock was called with the expected arguments
+     vellum_client.metric_definitions.execute_metric_definition.assert_called_once_with(
+         "example_metric_definition", inputs=[], release_tag="LATEST", request_options=None
+     )
vellum/workflows/nodes/displayable/prompt_deployment_node/node.py CHANGED
@@ -1,6 +1,7 @@
  import json
- from typing import Iterator
+ from typing import Any, Dict, Iterator, Type, Union

+ from vellum.workflows.constants import undefined
  from vellum.workflows.errors import WorkflowErrorCode
  from vellum.workflows.exceptions import NodeException
  from vellum.workflows.nodes.displayable.bases import BasePromptDeploymentNode as BasePromptDeploymentNode
@@ -11,7 +12,7 @@ from vellum.workflows.types.generics import StateType

  class PromptDeploymentNode(BasePromptDeploymentNode[StateType]):
      """
-     Used to execute a Prompt Deployment and surface a string output for convenience.
+     Used to execute a Prompt Deployment and surface a string output and json output if applicable for convenience.

      prompt_inputs: EntityInputsInterface - The inputs for the Prompt
      deployment: Union[UUID, str] - Either the Prompt Deployment's UUID or its name.
@@ -33,9 +34,11 @@ class PromptDeploymentNode(BasePromptDeploymentNode[StateType]):
          The outputs of the PromptDeploymentNode.

          text: str - The result of the Prompt Execution
+         json: Optional[Dict[Any, Any]] - The result of the Prompt Execution in JSON format
          """

          text: str
+         json: Union[Dict[Any, Any], Type[undefined]] = undefined

      def run(self) -> Iterator[BaseOutput]:
          outputs = yield from self._process_prompt_event_stream()
@@ -46,12 +49,18 @@ class PromptDeploymentNode(BasePromptDeploymentNode[StateType]):
              )

          string_outputs = []
+         json_output = None
+
          for output in outputs:
              if output.value is None:
                  continue

              if output.type == "STRING":
                  string_outputs.append(output.value)
+                 try:
+                     json_output = json.loads(output.value)
+                 except (json.JSONDecodeError, TypeError):
+                     pass
              elif output.type == "JSON":
                  string_outputs.append(json.dumps(output.value, indent=4))
              elif output.type == "FUNCTION_CALL":
@@ -61,3 +70,6 @@ class PromptDeploymentNode(BasePromptDeploymentNode[StateType]):

          value = "\n".join(string_outputs)
          yield BaseOutput(name="text", value=value)
+
+         if json_output:
+             yield BaseOutput(name="json", value=json_output)
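A standalone sketch of the parse-if-JSON behavior added above, independent of the SDK: when a STRING output happens to contain valid JSON, the node now also surfaces a parsed json output; otherwise the new output is left unset.

import json
from typing import Any, Optional


def parse_json_output(string_output: str) -> Optional[Any]:
    """Simplified mirror of the node's new behavior: parse the STRING output if it is valid JSON."""
    try:
        return json.loads(string_output)
    except (json.JSONDecodeError, TypeError):
        return None


print(parse_json_output('{"city": "Paris"}'))  # {'city': 'Paris'}
print(parse_json_output("plain text"))         # None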