vellum-ai 0.14.35__py3-none-any.whl → 0.14.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/__init__.py +6 -0
- vellum/client/__init__.py +4 -4
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/resources/release_reviews/client.py +118 -1
- vellum/client/types/__init__.py +6 -0
- vellum/client/types/logical_operator.py +1 -0
- vellum/client/types/prompt_deployment_release.py +34 -0
- vellum/client/types/prompt_deployment_release_prompt_deployment.py +19 -0
- vellum/client/types/prompt_deployment_release_prompt_version.py +19 -0
- vellum/types/prompt_deployment_release.py +3 -0
- vellum/types/prompt_deployment_release_prompt_deployment.py +3 -0
- vellum/types/prompt_deployment_release_prompt_version.py +3 -0
- vellum/workflows/inputs/base.py +2 -1
- vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +2 -0
- vellum/workflows/nodes/displayable/guardrail_node/node.py +35 -12
- vellum/workflows/nodes/displayable/guardrail_node/test_node.py +88 -0
- vellum/workflows/nodes/displayable/prompt_deployment_node/node.py +14 -2
- vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +43 -0
- vellum/workflows/state/base.py +38 -3
- vellum/workflows/state/tests/test_state.py +49 -0
- vellum/workflows/workflows/base.py +17 -0
- vellum/workflows/workflows/tests/test_base_workflow.py +39 -0
- {vellum_ai-0.14.35.dist-info → vellum_ai-0.14.37.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.35.dist-info → vellum_ai-0.14.37.dist-info}/RECORD +42 -35
- vellum_cli/pull.py +3 -0
- vellum_cli/tests/test_pull.py +3 -1
- vellum_ee/workflows/display/base.py +9 -7
- vellum_ee/workflows/display/nodes/__init__.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/note_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +2 -0
- vellum_ee/workflows/display/nodes/vellum/tests/test_note_node.py +33 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +3 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +1 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +0 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -0
- vellum_ee/workflows/display/types.py +6 -7
- vellum_ee/workflows/display/vellum.py +5 -4
- vellum_ee/workflows/display/workflows/base_workflow_display.py +20 -19
- vellum_ee/workflows/display/workflows/vellum_workflow_display.py +11 -37
- {vellum_ai-0.14.35.dist-info → vellum_ai-0.14.37.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.35.dist-info → vellum_ai-0.14.37.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.35.dist-info → vellum_ai-0.14.37.dist-info}/entry_points.txt +0 -0
vellum/__init__.py
CHANGED
@@ -300,6 +300,9 @@ from .types import (
|
|
300
300
|
PromptDeploymentExpandMetaRequest,
|
301
301
|
PromptDeploymentInputRequest,
|
302
302
|
PromptDeploymentParentContext,
|
303
|
+
PromptDeploymentRelease,
|
304
|
+
PromptDeploymentReleasePromptDeployment,
|
305
|
+
PromptDeploymentReleasePromptVersion,
|
303
306
|
PromptExecConfig,
|
304
307
|
PromptExecutionMeta,
|
305
308
|
PromptNodeExecutionMeta,
|
@@ -922,6 +925,9 @@ __all__ = [
|
|
922
925
|
"PromptDeploymentExpandMetaRequest",
|
923
926
|
"PromptDeploymentInputRequest",
|
924
927
|
"PromptDeploymentParentContext",
|
928
|
+
"PromptDeploymentRelease",
|
929
|
+
"PromptDeploymentReleasePromptDeployment",
|
930
|
+
"PromptDeploymentReleasePromptVersion",
|
925
931
|
"PromptExecConfig",
|
926
932
|
"PromptExecutionMeta",
|
927
933
|
"PromptNodeExecutionMeta",
|
vellum/client/__init__.py
CHANGED
@@ -7,6 +7,7 @@ from .core.client_wrapper import SyncClientWrapper
|
|
7
7
|
from .resources.ad_hoc.client import AdHocClient
|
8
8
|
from .resources.container_images.client import ContainerImagesClient
|
9
9
|
from .resources.deployments.client import DeploymentsClient
|
10
|
+
from .resources.release_reviews.client import ReleaseReviewsClient
|
10
11
|
from .resources.document_indexes.client import DocumentIndexesClient
|
11
12
|
from .resources.documents.client import DocumentsClient
|
12
13
|
from .resources.folder_entities.client import FolderEntitiesClient
|
@@ -18,7 +19,6 @@ from .resources.sandboxes.client import SandboxesClient
|
|
18
19
|
from .resources.test_suite_runs.client import TestSuiteRunsClient
|
19
20
|
from .resources.test_suites.client import TestSuitesClient
|
20
21
|
from .resources.workflow_deployments.client import WorkflowDeploymentsClient
|
21
|
-
from .resources.release_reviews.client import ReleaseReviewsClient
|
22
22
|
from .resources.workflow_sandboxes.client import WorkflowSandboxesClient
|
23
23
|
from .resources.workflows.client import WorkflowsClient
|
24
24
|
from .resources.workspace_secrets.client import WorkspaceSecretsClient
|
@@ -65,6 +65,7 @@ from .core.client_wrapper import AsyncClientWrapper
|
|
65
65
|
from .resources.ad_hoc.client import AsyncAdHocClient
|
66
66
|
from .resources.container_images.client import AsyncContainerImagesClient
|
67
67
|
from .resources.deployments.client import AsyncDeploymentsClient
|
68
|
+
from .resources.release_reviews.client import AsyncReleaseReviewsClient
|
68
69
|
from .resources.document_indexes.client import AsyncDocumentIndexesClient
|
69
70
|
from .resources.documents.client import AsyncDocumentsClient
|
70
71
|
from .resources.folder_entities.client import AsyncFolderEntitiesClient
|
@@ -76,7 +77,6 @@ from .resources.sandboxes.client import AsyncSandboxesClient
|
|
76
77
|
from .resources.test_suite_runs.client import AsyncTestSuiteRunsClient
|
77
78
|
from .resources.test_suites.client import AsyncTestSuitesClient
|
78
79
|
from .resources.workflow_deployments.client import AsyncWorkflowDeploymentsClient
|
79
|
-
from .resources.release_reviews.client import AsyncReleaseReviewsClient
|
80
80
|
from .resources.workflow_sandboxes.client import AsyncWorkflowSandboxesClient
|
81
81
|
from .resources.workflows.client import AsyncWorkflowsClient
|
82
82
|
from .resources.workspace_secrets.client import AsyncWorkspaceSecretsClient
|
@@ -143,6 +143,7 @@ class Vellum:
|
|
143
143
|
self.ad_hoc = AdHocClient(client_wrapper=self._client_wrapper)
|
144
144
|
self.container_images = ContainerImagesClient(client_wrapper=self._client_wrapper)
|
145
145
|
self.deployments = DeploymentsClient(client_wrapper=self._client_wrapper)
|
146
|
+
self.release_reviews = ReleaseReviewsClient(client_wrapper=self._client_wrapper)
|
146
147
|
self.document_indexes = DocumentIndexesClient(client_wrapper=self._client_wrapper)
|
147
148
|
self.documents = DocumentsClient(client_wrapper=self._client_wrapper)
|
148
149
|
self.folder_entities = FolderEntitiesClient(client_wrapper=self._client_wrapper)
|
@@ -154,7 +155,6 @@ class Vellum:
|
|
154
155
|
self.test_suite_runs = TestSuiteRunsClient(client_wrapper=self._client_wrapper)
|
155
156
|
self.test_suites = TestSuitesClient(client_wrapper=self._client_wrapper)
|
156
157
|
self.workflow_deployments = WorkflowDeploymentsClient(client_wrapper=self._client_wrapper)
|
157
|
-
self.release_reviews = ReleaseReviewsClient(client_wrapper=self._client_wrapper)
|
158
158
|
self.workflow_sandboxes = WorkflowSandboxesClient(client_wrapper=self._client_wrapper)
|
159
159
|
self.workflows = WorkflowsClient(client_wrapper=self._client_wrapper)
|
160
160
|
self.workspace_secrets = WorkspaceSecretsClient(client_wrapper=self._client_wrapper)
|
@@ -1486,6 +1486,7 @@ class AsyncVellum:
|
|
1486
1486
|
self.ad_hoc = AsyncAdHocClient(client_wrapper=self._client_wrapper)
|
1487
1487
|
self.container_images = AsyncContainerImagesClient(client_wrapper=self._client_wrapper)
|
1488
1488
|
self.deployments = AsyncDeploymentsClient(client_wrapper=self._client_wrapper)
|
1489
|
+
self.release_reviews = AsyncReleaseReviewsClient(client_wrapper=self._client_wrapper)
|
1489
1490
|
self.document_indexes = AsyncDocumentIndexesClient(client_wrapper=self._client_wrapper)
|
1490
1491
|
self.documents = AsyncDocumentsClient(client_wrapper=self._client_wrapper)
|
1491
1492
|
self.folder_entities = AsyncFolderEntitiesClient(client_wrapper=self._client_wrapper)
|
@@ -1497,7 +1498,6 @@ class AsyncVellum:
|
|
1497
1498
|
self.test_suite_runs = AsyncTestSuiteRunsClient(client_wrapper=self._client_wrapper)
|
1498
1499
|
self.test_suites = AsyncTestSuitesClient(client_wrapper=self._client_wrapper)
|
1499
1500
|
self.workflow_deployments = AsyncWorkflowDeploymentsClient(client_wrapper=self._client_wrapper)
|
1500
|
-
self.release_reviews = AsyncReleaseReviewsClient(client_wrapper=self._client_wrapper)
|
1501
1501
|
self.workflow_sandboxes = AsyncWorkflowSandboxesClient(client_wrapper=self._client_wrapper)
|
1502
1502
|
self.workflows = AsyncWorkflowsClient(client_wrapper=self._client_wrapper)
|
1503
1503
|
self.workspace_secrets = AsyncWorkspaceSecretsClient(client_wrapper=self._client_wrapper)
|
@@ -18,7 +18,7 @@ class BaseClientWrapper:
|
|
18
18
|
headers: typing.Dict[str, str] = {
|
19
19
|
"X-Fern-Language": "Python",
|
20
20
|
"X-Fern-SDK-Name": "vellum-ai",
|
21
|
-
"X-Fern-SDK-Version": "0.14.
|
21
|
+
"X-Fern-SDK-Version": "0.14.37",
|
22
22
|
}
|
23
23
|
headers["X_API_KEY"] = self.api_key
|
24
24
|
return headers
|
@@ -3,11 +3,12 @@
|
|
3
3
|
from ...core.client_wrapper import SyncClientWrapper
|
4
4
|
import typing
|
5
5
|
from ...core.request_options import RequestOptions
|
6
|
-
from ...types.
|
6
|
+
from ...types.prompt_deployment_release import PromptDeploymentRelease
|
7
7
|
from ...core.jsonable_encoder import jsonable_encoder
|
8
8
|
from ...core.pydantic_utilities import parse_obj_as
|
9
9
|
from json.decoder import JSONDecodeError
|
10
10
|
from ...core.api_error import ApiError
|
11
|
+
from ...types.workflow_deployment_release import WorkflowDeploymentRelease
|
11
12
|
from ...core.client_wrapper import AsyncClientWrapper
|
12
13
|
|
13
14
|
|
@@ -15,6 +16,60 @@ class ReleaseReviewsClient:
|
|
15
16
|
def __init__(self, *, client_wrapper: SyncClientWrapper):
|
16
17
|
self._client_wrapper = client_wrapper
|
17
18
|
|
19
|
+
def retrieve_prompt_deployment_release(
|
20
|
+
self, id: str, release_id_or_release_tag: str, *, request_options: typing.Optional[RequestOptions] = None
|
21
|
+
) -> PromptDeploymentRelease:
|
22
|
+
"""
|
23
|
+
Retrieve a specific Prompt Deployment Release by either its UUID or the name of a Release Tag that points to it.
|
24
|
+
|
25
|
+
Parameters
|
26
|
+
----------
|
27
|
+
id : str
|
28
|
+
A UUID string identifying this deployment.
|
29
|
+
|
30
|
+
release_id_or_release_tag : str
|
31
|
+
Either the UUID of Prompt Deployment Release you'd like to retrieve, or the name of a Release Tag that's pointing to the Prompt Deployment Release you'd like to retrieve.
|
32
|
+
|
33
|
+
request_options : typing.Optional[RequestOptions]
|
34
|
+
Request-specific configuration.
|
35
|
+
|
36
|
+
Returns
|
37
|
+
-------
|
38
|
+
PromptDeploymentRelease
|
39
|
+
|
40
|
+
|
41
|
+
Examples
|
42
|
+
--------
|
43
|
+
from vellum import Vellum
|
44
|
+
|
45
|
+
client = Vellum(
|
46
|
+
api_key="YOUR_API_KEY",
|
47
|
+
)
|
48
|
+
client.release_reviews.retrieve_prompt_deployment_release(
|
49
|
+
id="id",
|
50
|
+
release_id_or_release_tag="release_id_or_release_tag",
|
51
|
+
)
|
52
|
+
"""
|
53
|
+
_response = self._client_wrapper.httpx_client.request(
|
54
|
+
f"v1/deployments/{jsonable_encoder(id)}/releases/{jsonable_encoder(release_id_or_release_tag)}",
|
55
|
+
base_url=self._client_wrapper.get_environment().default,
|
56
|
+
method="GET",
|
57
|
+
request_options=request_options,
|
58
|
+
)
|
59
|
+
try:
|
60
|
+
if 200 <= _response.status_code < 300:
|
61
|
+
return typing.cast(
|
62
|
+
PromptDeploymentRelease,
|
63
|
+
parse_obj_as(
|
64
|
+
type_=PromptDeploymentRelease, # type: ignore
|
65
|
+
object_=_response.json(),
|
66
|
+
),
|
67
|
+
)
|
68
|
+
_response_json = _response.json()
|
69
|
+
except JSONDecodeError:
|
70
|
+
raise ApiError(status_code=_response.status_code, body=_response.text)
|
71
|
+
raise ApiError(status_code=_response.status_code, body=_response_json)
|
72
|
+
|
18
73
|
def retrieve_workflow_deployment_release(
|
19
74
|
self, id: str, release_id_or_release_tag: str, *, request_options: typing.Optional[RequestOptions] = None
|
20
75
|
) -> WorkflowDeploymentRelease:
|
@@ -74,6 +129,68 @@ class AsyncReleaseReviewsClient:
|
|
74
129
|
def __init__(self, *, client_wrapper: AsyncClientWrapper):
|
75
130
|
self._client_wrapper = client_wrapper
|
76
131
|
|
132
|
+
async def retrieve_prompt_deployment_release(
|
133
|
+
self, id: str, release_id_or_release_tag: str, *, request_options: typing.Optional[RequestOptions] = None
|
134
|
+
) -> PromptDeploymentRelease:
|
135
|
+
"""
|
136
|
+
Retrieve a specific Prompt Deployment Release by either its UUID or the name of a Release Tag that points to it.
|
137
|
+
|
138
|
+
Parameters
|
139
|
+
----------
|
140
|
+
id : str
|
141
|
+
A UUID string identifying this deployment.
|
142
|
+
|
143
|
+
release_id_or_release_tag : str
|
144
|
+
Either the UUID of Prompt Deployment Release you'd like to retrieve, or the name of a Release Tag that's pointing to the Prompt Deployment Release you'd like to retrieve.
|
145
|
+
|
146
|
+
request_options : typing.Optional[RequestOptions]
|
147
|
+
Request-specific configuration.
|
148
|
+
|
149
|
+
Returns
|
150
|
+
-------
|
151
|
+
PromptDeploymentRelease
|
152
|
+
|
153
|
+
|
154
|
+
Examples
|
155
|
+
--------
|
156
|
+
import asyncio
|
157
|
+
|
158
|
+
from vellum import AsyncVellum
|
159
|
+
|
160
|
+
client = AsyncVellum(
|
161
|
+
api_key="YOUR_API_KEY",
|
162
|
+
)
|
163
|
+
|
164
|
+
|
165
|
+
async def main() -> None:
|
166
|
+
await client.release_reviews.retrieve_prompt_deployment_release(
|
167
|
+
id="id",
|
168
|
+
release_id_or_release_tag="release_id_or_release_tag",
|
169
|
+
)
|
170
|
+
|
171
|
+
|
172
|
+
asyncio.run(main())
|
173
|
+
"""
|
174
|
+
_response = await self._client_wrapper.httpx_client.request(
|
175
|
+
f"v1/deployments/{jsonable_encoder(id)}/releases/{jsonable_encoder(release_id_or_release_tag)}",
|
176
|
+
base_url=self._client_wrapper.get_environment().default,
|
177
|
+
method="GET",
|
178
|
+
request_options=request_options,
|
179
|
+
)
|
180
|
+
try:
|
181
|
+
if 200 <= _response.status_code < 300:
|
182
|
+
return typing.cast(
|
183
|
+
PromptDeploymentRelease,
|
184
|
+
parse_obj_as(
|
185
|
+
type_=PromptDeploymentRelease, # type: ignore
|
186
|
+
object_=_response.json(),
|
187
|
+
),
|
188
|
+
)
|
189
|
+
_response_json = _response.json()
|
190
|
+
except JSONDecodeError:
|
191
|
+
raise ApiError(status_code=_response.status_code, body=_response.text)
|
192
|
+
raise ApiError(status_code=_response.status_code, body=_response_json)
|
193
|
+
|
77
194
|
async def retrieve_workflow_deployment_release(
|
78
195
|
self, id: str, release_id_or_release_tag: str, *, request_options: typing.Optional[RequestOptions] = None
|
79
196
|
) -> WorkflowDeploymentRelease:
|
vellum/client/types/__init__.py
CHANGED
@@ -308,6 +308,9 @@ from .prompt_block_state import PromptBlockState
|
|
308
308
|
from .prompt_deployment_expand_meta_request import PromptDeploymentExpandMetaRequest
|
309
309
|
from .prompt_deployment_input_request import PromptDeploymentInputRequest
|
310
310
|
from .prompt_deployment_parent_context import PromptDeploymentParentContext
|
311
|
+
from .prompt_deployment_release import PromptDeploymentRelease
|
312
|
+
from .prompt_deployment_release_prompt_deployment import PromptDeploymentReleasePromptDeployment
|
313
|
+
from .prompt_deployment_release_prompt_version import PromptDeploymentReleasePromptVersion
|
311
314
|
from .prompt_exec_config import PromptExecConfig
|
312
315
|
from .prompt_execution_meta import PromptExecutionMeta
|
313
316
|
from .prompt_node_execution_meta import PromptNodeExecutionMeta
|
@@ -902,6 +905,9 @@ __all__ = [
|
|
902
905
|
"PromptDeploymentExpandMetaRequest",
|
903
906
|
"PromptDeploymentInputRequest",
|
904
907
|
"PromptDeploymentParentContext",
|
908
|
+
"PromptDeploymentRelease",
|
909
|
+
"PromptDeploymentReleasePromptDeployment",
|
910
|
+
"PromptDeploymentReleasePromptVersion",
|
905
911
|
"PromptExecConfig",
|
906
912
|
"PromptExecutionMeta",
|
907
913
|
"PromptNodeExecutionMeta",
|
@@ -0,0 +1,34 @@
|
|
1
|
+
# This file was auto-generated by Fern from our API Definition.
|
2
|
+
|
3
|
+
from ..core.pydantic_utilities import UniversalBaseModel
|
4
|
+
import datetime as dt
|
5
|
+
from .release_environment import ReleaseEnvironment
|
6
|
+
import typing
|
7
|
+
from .release_created_by import ReleaseCreatedBy
|
8
|
+
from .prompt_deployment_release_prompt_version import PromptDeploymentReleasePromptVersion
|
9
|
+
from .prompt_deployment_release_prompt_deployment import PromptDeploymentReleasePromptDeployment
|
10
|
+
from .release_release_tag import ReleaseReleaseTag
|
11
|
+
from .slim_release_review import SlimReleaseReview
|
12
|
+
from ..core.pydantic_utilities import IS_PYDANTIC_V2
|
13
|
+
import pydantic
|
14
|
+
|
15
|
+
|
16
|
+
class PromptDeploymentRelease(UniversalBaseModel):
|
17
|
+
id: str
|
18
|
+
created: dt.datetime
|
19
|
+
environment: ReleaseEnvironment
|
20
|
+
created_by: typing.Optional[ReleaseCreatedBy] = None
|
21
|
+
prompt_version: PromptDeploymentReleasePromptVersion
|
22
|
+
deployment: PromptDeploymentReleasePromptDeployment
|
23
|
+
description: typing.Optional[str] = None
|
24
|
+
release_tags: typing.List[ReleaseReleaseTag]
|
25
|
+
reviews: typing.List[SlimReleaseReview]
|
26
|
+
|
27
|
+
if IS_PYDANTIC_V2:
|
28
|
+
model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
|
29
|
+
else:
|
30
|
+
|
31
|
+
class Config:
|
32
|
+
frozen = True
|
33
|
+
smart_union = True
|
34
|
+
extra = pydantic.Extra.allow
|
@@ -0,0 +1,19 @@
|
|
1
|
+
# This file was auto-generated by Fern from our API Definition.
|
2
|
+
|
3
|
+
from ..core.pydantic_utilities import UniversalBaseModel
|
4
|
+
from ..core.pydantic_utilities import IS_PYDANTIC_V2
|
5
|
+
import typing
|
6
|
+
import pydantic
|
7
|
+
|
8
|
+
|
9
|
+
class PromptDeploymentReleasePromptDeployment(UniversalBaseModel):
|
10
|
+
name: str
|
11
|
+
|
12
|
+
if IS_PYDANTIC_V2:
|
13
|
+
model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
|
14
|
+
else:
|
15
|
+
|
16
|
+
class Config:
|
17
|
+
frozen = True
|
18
|
+
smart_union = True
|
19
|
+
extra = pydantic.Extra.allow
|
@@ -0,0 +1,19 @@
|
|
1
|
+
# This file was auto-generated by Fern from our API Definition.
|
2
|
+
|
3
|
+
from ..core.pydantic_utilities import UniversalBaseModel
|
4
|
+
from ..core.pydantic_utilities import IS_PYDANTIC_V2
|
5
|
+
import typing
|
6
|
+
import pydantic
|
7
|
+
|
8
|
+
|
9
|
+
class PromptDeploymentReleasePromptVersion(UniversalBaseModel):
|
10
|
+
id: str
|
11
|
+
|
12
|
+
if IS_PYDANTIC_V2:
|
13
|
+
model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
|
14
|
+
else:
|
15
|
+
|
16
|
+
class Config:
|
17
|
+
frozen = True
|
18
|
+
smart_union = True
|
19
|
+
extra = pydantic.Extra.allow
|
vellum/workflows/inputs/base.py
CHANGED
@@ -4,6 +4,7 @@ from typing_extensions import dataclass_transform
|
|
4
4
|
from pydantic import GetCoreSchemaHandler
|
5
5
|
from pydantic_core import core_schema
|
6
6
|
|
7
|
+
from vellum.workflows.constants import undefined
|
7
8
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
8
9
|
from vellum.workflows.exceptions import WorkflowInitializationException
|
9
10
|
from vellum.workflows.references import ExternalInputReference, WorkflowInputReference
|
@@ -15,7 +16,7 @@ from vellum.workflows.types.utils import get_class_attr_names, infer_types
|
|
15
16
|
class _BaseInputsMeta(type):
|
16
17
|
def __getattribute__(cls, name: str) -> Any:
|
17
18
|
if not name.startswith("_") and name in cls.__annotations__ and issubclass(cls, BaseInputs):
|
18
|
-
instance = vars(cls).get(name)
|
19
|
+
instance = vars(cls).get(name, undefined)
|
19
20
|
types = infer_types(cls, name)
|
20
21
|
|
21
22
|
if getattr(cls, "__descriptor_class__", None) is ExternalInputReference:
|
@@ -38,6 +38,7 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
|
|
38
38
|
expand_raw: Optional[Sequence[str]] - Expandable raw fields to include in the response
|
39
39
|
metadata: Optional[Dict[str, Optional[Any]]] - The metadata to use for the Prompt Execution
|
40
40
|
request_options: Optional[RequestOptions] - The request options to use for the Prompt Execution
|
41
|
+
ml_model_fallback: Optional[Sequence[str]] - ML model fallbacks to use
|
41
42
|
"""
|
42
43
|
|
43
44
|
# Either the Prompt Deployment's UUID or its name.
|
@@ -50,6 +51,7 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
|
|
50
51
|
raw_overrides: Optional[RawPromptExecutionOverridesRequest] = OMIT
|
51
52
|
expand_raw: Optional[Sequence[str]] = OMIT
|
52
53
|
metadata: Optional[Dict[str, Optional[Any]]] = OMIT
|
54
|
+
ml_model_fallbacks: Optional[Sequence[str]] = OMIT
|
53
55
|
|
54
56
|
class Trigger(BasePromptNode.Trigger):
|
55
57
|
merge_behavior = MergeBehavior.AWAIT_ANY
|
@@ -2,6 +2,7 @@ from uuid import UUID
|
|
2
2
|
from typing import Any, ClassVar, Dict, Generic, List, Optional, Union, cast
|
3
3
|
|
4
4
|
from vellum import ChatHistoryInput, ChatMessage, JsonInput, MetricDefinitionInput, NumberInput, StringInput
|
5
|
+
from vellum.client import ApiError
|
5
6
|
from vellum.core import RequestOptions
|
6
7
|
from vellum.workflows.constants import LATEST_RELEASE_TAG
|
7
8
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
@@ -34,28 +35,50 @@ class GuardrailNode(BaseNode[StateType], Generic[StateType]):
|
|
34
35
|
|
35
36
|
class Outputs(BaseOutputs):
|
36
37
|
score: float
|
38
|
+
normalized_score: Optional[float]
|
37
39
|
log: Optional[str]
|
38
40
|
|
39
41
|
def run(self) -> Outputs:
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
42
|
+
try:
|
43
|
+
metric_execution = self._context.vellum_client.metric_definitions.execute_metric_definition(
|
44
|
+
self.metric_definition if isinstance(self.metric_definition, str) else str(self.metric_definition),
|
45
|
+
inputs=self._compile_metric_inputs(),
|
46
|
+
release_tag=self.release_tag,
|
47
|
+
request_options=self.request_options,
|
48
|
+
)
|
49
|
+
|
50
|
+
except ApiError:
|
51
|
+
raise NodeException(
|
52
|
+
code=WorkflowErrorCode.NODE_EXECUTION,
|
53
|
+
message="Failed to execute metric definition",
|
54
|
+
)
|
46
55
|
|
47
56
|
metric_outputs = {output.name: output.value for output in metric_execution.outputs}
|
48
57
|
|
49
|
-
|
58
|
+
SCORE_KEY = "score"
|
59
|
+
NORMALIZED_SCORE_KEY = "normalized_score"
|
60
|
+
LOG_KEY = "log"
|
61
|
+
|
62
|
+
score = metric_outputs.get(SCORE_KEY)
|
50
63
|
if not isinstance(score, float):
|
51
64
|
raise NodeException(
|
52
|
-
message="Metric execution must have one output named '
|
65
|
+
message=f"Metric execution must have one output named '{SCORE_KEY}' with type 'float'",
|
53
66
|
code=WorkflowErrorCode.INVALID_OUTPUTS,
|
54
67
|
)
|
55
|
-
metric_outputs.pop(
|
68
|
+
metric_outputs.pop(SCORE_KEY)
|
69
|
+
|
70
|
+
if NORMALIZED_SCORE_KEY in metric_outputs:
|
71
|
+
normalized_score = metric_outputs.pop(NORMALIZED_SCORE_KEY)
|
72
|
+
if not isinstance(normalized_score, float):
|
73
|
+
raise NodeException(
|
74
|
+
message=f"Metric execution must have one output named '{NORMALIZED_SCORE_KEY}' with type 'float'",
|
75
|
+
code=WorkflowErrorCode.INVALID_OUTPUTS,
|
76
|
+
)
|
77
|
+
else:
|
78
|
+
normalized_score = None
|
56
79
|
|
57
|
-
if
|
58
|
-
log = metric_outputs.pop(
|
80
|
+
if LOG_KEY in metric_outputs:
|
81
|
+
log = metric_outputs.pop(LOG_KEY) or ""
|
59
82
|
if not isinstance(log, str):
|
60
83
|
raise NodeException(
|
61
84
|
message="Metric execution log output must be of type 'str'",
|
@@ -64,7 +87,7 @@ class GuardrailNode(BaseNode[StateType], Generic[StateType]):
|
|
64
87
|
else:
|
65
88
|
log = None
|
66
89
|
|
67
|
-
return self.Outputs(score=score, log=log, **metric_outputs)
|
90
|
+
return self.Outputs(score=score, normalized_score=normalized_score, log=log, **metric_outputs)
|
68
91
|
|
69
92
|
def _compile_metric_inputs(self) -> List[MetricDefinitionInput]:
|
70
93
|
# TODO: We may want to consolidate with prompt deployment input compilation
|
@@ -1,8 +1,11 @@
|
|
1
1
|
import pytest
|
2
2
|
|
3
3
|
from vellum import TestSuiteRunMetricNumberOutput
|
4
|
+
from vellum.client import ApiError
|
4
5
|
from vellum.client.types.metric_definition_execution import MetricDefinitionExecution
|
5
6
|
from vellum.client.types.test_suite_run_metric_string_output import TestSuiteRunMetricStringOutput
|
7
|
+
from vellum.workflows.errors import WorkflowErrorCode
|
8
|
+
from vellum.workflows.exceptions import NodeException
|
6
9
|
from vellum.workflows.nodes.displayable.guardrail_node.node import GuardrailNode
|
7
10
|
|
8
11
|
|
@@ -36,3 +39,88 @@ def test_run_guardrail_node__empty_log(vellum_client, log_value):
|
|
36
39
|
# THEN the workflow should have completed successfully
|
37
40
|
assert outputs.score == 0.6
|
38
41
|
assert outputs.log == ""
|
42
|
+
|
43
|
+
|
44
|
+
def test_run_guardrail_node__normalized_score(vellum_client):
|
45
|
+
"""Confirm that we can successfully invoke a Guardrail Node"""
|
46
|
+
|
47
|
+
# GIVEN a Guardrail Node
|
48
|
+
class MyGuard(GuardrailNode):
|
49
|
+
metric_definition = "example_metric_definition"
|
50
|
+
metric_inputs = {}
|
51
|
+
|
52
|
+
# AND we know that the guardrail node will return a normalized score
|
53
|
+
mock_metric_execution = MetricDefinitionExecution(
|
54
|
+
outputs=[
|
55
|
+
TestSuiteRunMetricNumberOutput(
|
56
|
+
name="score",
|
57
|
+
value=0.6,
|
58
|
+
),
|
59
|
+
TestSuiteRunMetricNumberOutput(
|
60
|
+
name="normalized_score",
|
61
|
+
value=1.0,
|
62
|
+
),
|
63
|
+
],
|
64
|
+
)
|
65
|
+
vellum_client.metric_definitions.execute_metric_definition.return_value = mock_metric_execution
|
66
|
+
|
67
|
+
# WHEN we run the Guardrail Node
|
68
|
+
outputs = MyGuard().run()
|
69
|
+
|
70
|
+
# THEN the workflow should have completed successfully
|
71
|
+
assert outputs.score == 0.6
|
72
|
+
assert outputs.normalized_score == 1.0
|
73
|
+
|
74
|
+
|
75
|
+
def test_run_guardrail_node__normalized_score_null(vellum_client):
|
76
|
+
# GIVEN a Guardrail Node
|
77
|
+
class MyGuard(GuardrailNode):
|
78
|
+
metric_definition = "example_metric_definition"
|
79
|
+
metric_inputs = {}
|
80
|
+
|
81
|
+
# AND we know that the guardrail node will return a normalized score that is None
|
82
|
+
mock_metric_execution = MetricDefinitionExecution(
|
83
|
+
outputs=[
|
84
|
+
TestSuiteRunMetricNumberOutput(
|
85
|
+
name="score",
|
86
|
+
value=0.6,
|
87
|
+
),
|
88
|
+
TestSuiteRunMetricNumberOutput(
|
89
|
+
name="normalized_score",
|
90
|
+
value=None,
|
91
|
+
),
|
92
|
+
],
|
93
|
+
)
|
94
|
+
vellum_client.metric_definitions.execute_metric_definition.return_value = mock_metric_execution
|
95
|
+
|
96
|
+
# WHEN we run the Guardrail Node
|
97
|
+
with pytest.raises(NodeException) as exc_info:
|
98
|
+
MyGuard().run()
|
99
|
+
|
100
|
+
# THEN we get an exception
|
101
|
+
assert exc_info.value.message == "Metric execution must have one output named 'normalized_score' with type 'float'"
|
102
|
+
assert exc_info.value.code == WorkflowErrorCode.INVALID_OUTPUTS
|
103
|
+
|
104
|
+
|
105
|
+
def test_run_guardrail_node__api_error(vellum_client):
|
106
|
+
# GIVEN a Guardrail Node
|
107
|
+
class MyGuard(GuardrailNode):
|
108
|
+
metric_definition = "example_metric_definition"
|
109
|
+
metric_inputs = {}
|
110
|
+
|
111
|
+
# AND the API client raises an ApiError when called
|
112
|
+
api_error = ApiError(status_code=503)
|
113
|
+
vellum_client.metric_definitions.execute_metric_definition.side_effect = api_error
|
114
|
+
|
115
|
+
# WHEN we run the Guardrail Node
|
116
|
+
with pytest.raises(NodeException) as exc_info:
|
117
|
+
MyGuard().run()
|
118
|
+
|
119
|
+
# THEN we get a NodeException with the appropriate error code
|
120
|
+
assert exc_info.value.code == WorkflowErrorCode.NODE_EXECUTION
|
121
|
+
assert "Failed to execute metric definition" in exc_info.value.message
|
122
|
+
|
123
|
+
# Verify the mock was called with the expected arguments
|
124
|
+
vellum_client.metric_definitions.execute_metric_definition.assert_called_once_with(
|
125
|
+
"example_metric_definition", inputs=[], release_tag="LATEST", request_options=None
|
126
|
+
)
|
@@ -1,6 +1,7 @@
|
|
1
1
|
import json
|
2
|
-
from typing import Iterator
|
2
|
+
from typing import Any, Dict, Iterator, Type, Union
|
3
3
|
|
4
|
+
from vellum.workflows.constants import undefined
|
4
5
|
from vellum.workflows.errors import WorkflowErrorCode
|
5
6
|
from vellum.workflows.exceptions import NodeException
|
6
7
|
from vellum.workflows.nodes.displayable.bases import BasePromptDeploymentNode as BasePromptDeploymentNode
|
@@ -11,7 +12,7 @@ from vellum.workflows.types.generics import StateType
|
|
11
12
|
|
12
13
|
class PromptDeploymentNode(BasePromptDeploymentNode[StateType]):
|
13
14
|
"""
|
14
|
-
Used to execute a Prompt Deployment and surface a string output for convenience.
|
15
|
+
Used to execute a Prompt Deployment and surface a string output and json output if applicable for convenience.
|
15
16
|
|
16
17
|
prompt_inputs: EntityInputsInterface - The inputs for the Prompt
|
17
18
|
deployment: Union[UUID, str] - Either the Prompt Deployment's UUID or its name.
|
@@ -33,9 +34,11 @@ class PromptDeploymentNode(BasePromptDeploymentNode[StateType]):
|
|
33
34
|
The outputs of the PromptDeploymentNode.
|
34
35
|
|
35
36
|
text: str - The result of the Prompt Execution
|
37
|
+
json: Optional[Dict[Any, Any]] - The result of the Prompt Execution in JSON format
|
36
38
|
"""
|
37
39
|
|
38
40
|
text: str
|
41
|
+
json: Union[Dict[Any, Any], Type[undefined]] = undefined
|
39
42
|
|
40
43
|
def run(self) -> Iterator[BaseOutput]:
|
41
44
|
outputs = yield from self._process_prompt_event_stream()
|
@@ -46,12 +49,18 @@ class PromptDeploymentNode(BasePromptDeploymentNode[StateType]):
|
|
46
49
|
)
|
47
50
|
|
48
51
|
string_outputs = []
|
52
|
+
json_output = None
|
53
|
+
|
49
54
|
for output in outputs:
|
50
55
|
if output.value is None:
|
51
56
|
continue
|
52
57
|
|
53
58
|
if output.type == "STRING":
|
54
59
|
string_outputs.append(output.value)
|
60
|
+
try:
|
61
|
+
json_output = json.loads(output.value)
|
62
|
+
except (json.JSONDecodeError, TypeError):
|
63
|
+
pass
|
55
64
|
elif output.type == "JSON":
|
56
65
|
string_outputs.append(json.dumps(output.value, indent=4))
|
57
66
|
elif output.type == "FUNCTION_CALL":
|
@@ -61,3 +70,6 @@ class PromptDeploymentNode(BasePromptDeploymentNode[StateType]):
|
|
61
70
|
|
62
71
|
value = "\n".join(string_outputs)
|
63
72
|
yield BaseOutput(name="text", value=value)
|
73
|
+
|
74
|
+
if json_output:
|
75
|
+
yield BaseOutput(name="json", value=json_output)
|