vellum-ai 1.3.8__py3-none-any.whl → 1.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +2 -2
- vellum/client/reference.md +71 -0
- vellum/client/resources/workflows/client.py +80 -0
- vellum/client/resources/workflows/raw_client.py +98 -0
- vellum/client/types/vellum_error.py +2 -1
- vellum/client/types/vellum_error_request.py +2 -1
- vellum/workflows/errors/types.py +3 -1
- vellum/workflows/events/tests/test_event.py +1 -0
- vellum/workflows/exceptions.py +11 -2
- vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py +2 -0
- vellum/workflows/types/definition.py +32 -1
- vellum/workflows/utils/tests/test_vellum_variables.py +7 -1
- vellum/workflows/utils/vellum_variables.py +42 -3
- {vellum_ai-1.3.8.dist-info → vellum_ai-1.3.10.dist-info}/METADATA +1 -1
- {vellum_ai-1.3.8.dist-info → vellum_ai-1.3.10.dist-info}/RECORD +38 -38
- vellum_ee/workflows/display/editor/types.py +2 -0
- vellum_ee/workflows/display/nodes/base_node_display.py +42 -14
- vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +64 -0
- vellum_ee/workflows/display/nodes/vellum/final_output_node.py +1 -1
- vellum_ee/workflows/display/nodes/vellum/retry_node.py +1 -1
- vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_node.py +12 -12
- vellum_ee/workflows/display/nodes/vellum/tests/test_tool_calling_node.py +4 -4
- vellum_ee/workflows/display/nodes/vellum/try_node.py +1 -1
- vellum_ee/workflows/display/tests/test_base_workflow_display.py +46 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +1 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_prompt_node_serialization.py +8 -8
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_inline_workflow_serialization.py +2 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_serialization.py +2 -1
- vellum_ee/workflows/display/utils/events.py +7 -1
- vellum_ee/workflows/display/utils/expressions.py +35 -42
- vellum_ee/workflows/display/utils/tests/test_events.py +4 -4
- vellum_ee/workflows/display/workflows/base_workflow_display.py +1 -1
- vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +10 -10
- {vellum_ai-1.3.8.dist-info → vellum_ai-1.3.10.dist-info}/LICENSE +0 -0
- {vellum_ai-1.3.8.dist-info → vellum_ai-1.3.10.dist-info}/WHEEL +0 -0
- {vellum_ai-1.3.8.dist-info → vellum_ai-1.3.10.dist-info}/entry_points.txt +0 -0
@@ -27,10 +27,10 @@ class BaseClientWrapper:
|
|
27
27
|
|
28
28
|
def get_headers(self) -> typing.Dict[str, str]:
|
29
29
|
headers: typing.Dict[str, str] = {
|
30
|
-
"User-Agent": "vellum-ai/1.3.
|
30
|
+
"User-Agent": "vellum-ai/1.3.10",
|
31
31
|
"X-Fern-Language": "Python",
|
32
32
|
"X-Fern-SDK-Name": "vellum-ai",
|
33
|
-
"X-Fern-SDK-Version": "1.3.
|
33
|
+
"X-Fern-SDK-Version": "1.3.10",
|
34
34
|
**(self.get_custom_headers() or {}),
|
35
35
|
}
|
36
36
|
if self._api_version is not None:
|
vellum/client/reference.md
CHANGED
@@ -6426,6 +6426,77 @@ client.workflow_sandboxes.list_workflow_sandbox_examples()
|
|
6426
6426
|
</details>
|
6427
6427
|
|
6428
6428
|
## Workflows
|
6429
|
+
<details><summary><code>client.workflows.<a href="src/vellum/resources/workflows/client.py">serialize_workflow_files</a>(...)</code></summary>
|
6430
|
+
<dl>
|
6431
|
+
<dd>
|
6432
|
+
|
6433
|
+
#### 📝 Description
|
6434
|
+
|
6435
|
+
<dl>
|
6436
|
+
<dd>
|
6437
|
+
|
6438
|
+
<dl>
|
6439
|
+
<dd>
|
6440
|
+
|
6441
|
+
Serialize files
|
6442
|
+
</dd>
|
6443
|
+
</dl>
|
6444
|
+
</dd>
|
6445
|
+
</dl>
|
6446
|
+
|
6447
|
+
#### 🔌 Usage
|
6448
|
+
|
6449
|
+
<dl>
|
6450
|
+
<dd>
|
6451
|
+
|
6452
|
+
<dl>
|
6453
|
+
<dd>
|
6454
|
+
|
6455
|
+
```python
|
6456
|
+
from vellum import Vellum
|
6457
|
+
|
6458
|
+
client = Vellum(
|
6459
|
+
api_version="YOUR_API_VERSION",
|
6460
|
+
api_key="YOUR_API_KEY",
|
6461
|
+
)
|
6462
|
+
client.workflows.serialize_workflow_files(
|
6463
|
+
files={"files": {"key": "value"}},
|
6464
|
+
)
|
6465
|
+
|
6466
|
+
```
|
6467
|
+
</dd>
|
6468
|
+
</dl>
|
6469
|
+
</dd>
|
6470
|
+
</dl>
|
6471
|
+
|
6472
|
+
#### ⚙️ Parameters
|
6473
|
+
|
6474
|
+
<dl>
|
6475
|
+
<dd>
|
6476
|
+
|
6477
|
+
<dl>
|
6478
|
+
<dd>
|
6479
|
+
|
6480
|
+
**files:** `typing.Dict[str, typing.Optional[typing.Any]]`
|
6481
|
+
|
6482
|
+
</dd>
|
6483
|
+
</dl>
|
6484
|
+
|
6485
|
+
<dl>
|
6486
|
+
<dd>
|
6487
|
+
|
6488
|
+
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
|
6489
|
+
|
6490
|
+
</dd>
|
6491
|
+
</dl>
|
6492
|
+
</dd>
|
6493
|
+
</dl>
|
6494
|
+
|
6495
|
+
|
6496
|
+
</dd>
|
6497
|
+
</dl>
|
6498
|
+
</details>
|
6499
|
+
|
6429
6500
|
## WorkspaceSecrets
|
6430
6501
|
<details><summary><code>client.workspace_secrets.<a href="src/vellum/resources/workspace_secrets/client.py">retrieve</a>(...)</code></summary>
|
6431
6502
|
<dl>
|
@@ -136,6 +136,42 @@ class WorkflowsClient:
|
|
136
136
|
)
|
137
137
|
return _response.data
|
138
138
|
|
139
|
+
def serialize_workflow_files(
|
140
|
+
self,
|
141
|
+
*,
|
142
|
+
files: typing.Dict[str, typing.Optional[typing.Any]],
|
143
|
+
request_options: typing.Optional[RequestOptions] = None,
|
144
|
+
) -> typing.Dict[str, typing.Optional[typing.Any]]:
|
145
|
+
"""
|
146
|
+
Serialize files
|
147
|
+
|
148
|
+
Parameters
|
149
|
+
----------
|
150
|
+
files : typing.Dict[str, typing.Optional[typing.Any]]
|
151
|
+
|
152
|
+
request_options : typing.Optional[RequestOptions]
|
153
|
+
Request-specific configuration.
|
154
|
+
|
155
|
+
Returns
|
156
|
+
-------
|
157
|
+
typing.Dict[str, typing.Optional[typing.Any]]
|
158
|
+
|
159
|
+
|
160
|
+
Examples
|
161
|
+
--------
|
162
|
+
from vellum import Vellum
|
163
|
+
|
164
|
+
client = Vellum(
|
165
|
+
api_version="YOUR_API_VERSION",
|
166
|
+
api_key="YOUR_API_KEY",
|
167
|
+
)
|
168
|
+
client.workflows.serialize_workflow_files(
|
169
|
+
files={"files": {"key": "value"}},
|
170
|
+
)
|
171
|
+
"""
|
172
|
+
_response = self._raw_client.serialize_workflow_files(files=files, request_options=request_options)
|
173
|
+
return _response.data
|
174
|
+
|
139
175
|
|
140
176
|
class AsyncWorkflowsClient:
|
141
177
|
def __init__(self, *, client_wrapper: AsyncClientWrapper):
|
@@ -255,3 +291,47 @@ class AsyncWorkflowsClient:
|
|
255
291
|
request_options=request_options,
|
256
292
|
)
|
257
293
|
return _response.data
|
294
|
+
|
295
|
+
async def serialize_workflow_files(
|
296
|
+
self,
|
297
|
+
*,
|
298
|
+
files: typing.Dict[str, typing.Optional[typing.Any]],
|
299
|
+
request_options: typing.Optional[RequestOptions] = None,
|
300
|
+
) -> typing.Dict[str, typing.Optional[typing.Any]]:
|
301
|
+
"""
|
302
|
+
Serialize files
|
303
|
+
|
304
|
+
Parameters
|
305
|
+
----------
|
306
|
+
files : typing.Dict[str, typing.Optional[typing.Any]]
|
307
|
+
|
308
|
+
request_options : typing.Optional[RequestOptions]
|
309
|
+
Request-specific configuration.
|
310
|
+
|
311
|
+
Returns
|
312
|
+
-------
|
313
|
+
typing.Dict[str, typing.Optional[typing.Any]]
|
314
|
+
|
315
|
+
|
316
|
+
Examples
|
317
|
+
--------
|
318
|
+
import asyncio
|
319
|
+
|
320
|
+
from vellum import AsyncVellum
|
321
|
+
|
322
|
+
client = AsyncVellum(
|
323
|
+
api_version="YOUR_API_VERSION",
|
324
|
+
api_key="YOUR_API_KEY",
|
325
|
+
)
|
326
|
+
|
327
|
+
|
328
|
+
async def main() -> None:
|
329
|
+
await client.workflows.serialize_workflow_files(
|
330
|
+
files={"files": {"key": "value"}},
|
331
|
+
)
|
332
|
+
|
333
|
+
|
334
|
+
asyncio.run(main())
|
335
|
+
"""
|
336
|
+
_response = await self._raw_client.serialize_workflow_files(files=files, request_options=request_options)
|
337
|
+
return _response.data
|
@@ -181,6 +181,55 @@ class RawWorkflowsClient:
|
|
181
181
|
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
|
182
182
|
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
|
183
183
|
|
184
|
+
def serialize_workflow_files(
|
185
|
+
self,
|
186
|
+
*,
|
187
|
+
files: typing.Dict[str, typing.Optional[typing.Any]],
|
188
|
+
request_options: typing.Optional[RequestOptions] = None,
|
189
|
+
) -> HttpResponse[typing.Dict[str, typing.Optional[typing.Any]]]:
|
190
|
+
"""
|
191
|
+
Serialize files
|
192
|
+
|
193
|
+
Parameters
|
194
|
+
----------
|
195
|
+
files : typing.Dict[str, typing.Optional[typing.Any]]
|
196
|
+
|
197
|
+
request_options : typing.Optional[RequestOptions]
|
198
|
+
Request-specific configuration.
|
199
|
+
|
200
|
+
Returns
|
201
|
+
-------
|
202
|
+
HttpResponse[typing.Dict[str, typing.Optional[typing.Any]]]
|
203
|
+
|
204
|
+
"""
|
205
|
+
_response = self._client_wrapper.httpx_client.request(
|
206
|
+
"v1/workflows/serialize",
|
207
|
+
base_url=self._client_wrapper.get_environment().default,
|
208
|
+
method="POST",
|
209
|
+
json={
|
210
|
+
"files": files,
|
211
|
+
},
|
212
|
+
headers={
|
213
|
+
"content-type": "application/json",
|
214
|
+
},
|
215
|
+
request_options=request_options,
|
216
|
+
omit=OMIT,
|
217
|
+
)
|
218
|
+
try:
|
219
|
+
if 200 <= _response.status_code < 300:
|
220
|
+
_data = typing.cast(
|
221
|
+
typing.Dict[str, typing.Optional[typing.Any]],
|
222
|
+
parse_obj_as(
|
223
|
+
type_=typing.Dict[str, typing.Optional[typing.Any]], # type: ignore
|
224
|
+
object_=_response.json(),
|
225
|
+
),
|
226
|
+
)
|
227
|
+
return HttpResponse(response=_response, data=_data)
|
228
|
+
_response_json = _response.json()
|
229
|
+
except JSONDecodeError:
|
230
|
+
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
|
231
|
+
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
|
232
|
+
|
184
233
|
|
185
234
|
class AsyncRawWorkflowsClient:
|
186
235
|
def __init__(self, *, client_wrapper: AsyncClientWrapper):
|
@@ -343,3 +392,52 @@ class AsyncRawWorkflowsClient:
|
|
343
392
|
except JSONDecodeError:
|
344
393
|
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
|
345
394
|
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
|
395
|
+
|
396
|
+
async def serialize_workflow_files(
|
397
|
+
self,
|
398
|
+
*,
|
399
|
+
files: typing.Dict[str, typing.Optional[typing.Any]],
|
400
|
+
request_options: typing.Optional[RequestOptions] = None,
|
401
|
+
) -> AsyncHttpResponse[typing.Dict[str, typing.Optional[typing.Any]]]:
|
402
|
+
"""
|
403
|
+
Serialize files
|
404
|
+
|
405
|
+
Parameters
|
406
|
+
----------
|
407
|
+
files : typing.Dict[str, typing.Optional[typing.Any]]
|
408
|
+
|
409
|
+
request_options : typing.Optional[RequestOptions]
|
410
|
+
Request-specific configuration.
|
411
|
+
|
412
|
+
Returns
|
413
|
+
-------
|
414
|
+
AsyncHttpResponse[typing.Dict[str, typing.Optional[typing.Any]]]
|
415
|
+
|
416
|
+
"""
|
417
|
+
_response = await self._client_wrapper.httpx_client.request(
|
418
|
+
"v1/workflows/serialize",
|
419
|
+
base_url=self._client_wrapper.get_environment().default,
|
420
|
+
method="POST",
|
421
|
+
json={
|
422
|
+
"files": files,
|
423
|
+
},
|
424
|
+
headers={
|
425
|
+
"content-type": "application/json",
|
426
|
+
},
|
427
|
+
request_options=request_options,
|
428
|
+
omit=OMIT,
|
429
|
+
)
|
430
|
+
try:
|
431
|
+
if 200 <= _response.status_code < 300:
|
432
|
+
_data = typing.cast(
|
433
|
+
typing.Dict[str, typing.Optional[typing.Any]],
|
434
|
+
parse_obj_as(
|
435
|
+
type_=typing.Dict[str, typing.Optional[typing.Any]], # type: ignore
|
436
|
+
object_=_response.json(),
|
437
|
+
),
|
438
|
+
)
|
439
|
+
return AsyncHttpResponse(response=_response, data=_data)
|
440
|
+
_response_json = _response.json()
|
441
|
+
except JSONDecodeError:
|
442
|
+
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
|
443
|
+
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
|
@@ -8,8 +8,9 @@ from .vellum_error_code_enum import VellumErrorCodeEnum
|
|
8
8
|
|
9
9
|
|
10
10
|
class VellumError(UniversalBaseModel):
|
11
|
-
message: str
|
12
11
|
code: VellumErrorCodeEnum
|
12
|
+
message: str
|
13
|
+
raw_data: typing.Optional[typing.Dict[str, typing.Optional[typing.Any]]] = None
|
13
14
|
|
14
15
|
if IS_PYDANTIC_V2:
|
15
16
|
model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
|
@@ -8,8 +8,9 @@ from .vellum_error_code_enum import VellumErrorCodeEnum
|
|
8
8
|
|
9
9
|
|
10
10
|
class VellumErrorRequest(UniversalBaseModel):
|
11
|
-
message: str
|
12
11
|
code: VellumErrorCodeEnum
|
12
|
+
message: str
|
13
|
+
raw_data: typing.Optional[typing.Dict[str, typing.Optional[typing.Any]]] = None
|
13
14
|
|
14
15
|
if IS_PYDANTIC_V2:
|
15
16
|
model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
|
vellum/workflows/errors/types.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
from dataclasses import dataclass
|
2
2
|
from enum import Enum
|
3
3
|
import logging
|
4
|
-
from typing import Any, Dict
|
4
|
+
from typing import Any, Dict, Optional
|
5
5
|
|
6
6
|
from vellum.client.types.vellum_error import VellumError
|
7
7
|
from vellum.client.types.vellum_error_code_enum import VellumErrorCodeEnum
|
@@ -30,6 +30,7 @@ class WorkflowErrorCode(Enum):
|
|
30
30
|
class WorkflowError:
|
31
31
|
message: str
|
32
32
|
code: WorkflowErrorCode
|
33
|
+
raw_data: Optional[Dict[str, Any]] = None
|
33
34
|
|
34
35
|
def __contains__(self, item: Any) -> bool:
|
35
36
|
return item in self.message
|
@@ -55,6 +56,7 @@ def vellum_error_to_workflow_error(error: VellumError) -> WorkflowError:
|
|
55
56
|
return WorkflowError(
|
56
57
|
message=error.message,
|
57
58
|
code=workflow_error_code,
|
59
|
+
raw_data=error.raw_data or {},
|
58
60
|
)
|
59
61
|
|
60
62
|
|
vellum/workflows/exceptions.py
CHANGED
@@ -1,10 +1,18 @@
|
|
1
|
+
from typing import Any, Dict, Optional
|
2
|
+
|
1
3
|
from vellum.workflows.errors import WorkflowError, WorkflowErrorCode
|
2
4
|
|
3
5
|
|
4
6
|
class NodeException(Exception):
|
5
|
-
def __init__(
|
7
|
+
def __init__(
|
8
|
+
self,
|
9
|
+
message: str,
|
10
|
+
code: WorkflowErrorCode = WorkflowErrorCode.INTERNAL_ERROR,
|
11
|
+
raw_data: Optional[Dict[str, Any]] = None,
|
12
|
+
):
|
6
13
|
self.message = message
|
7
14
|
self.code = code
|
15
|
+
self.raw_data = raw_data
|
8
16
|
super().__init__(message)
|
9
17
|
|
10
18
|
@property
|
@@ -12,11 +20,12 @@ class NodeException(Exception):
|
|
12
20
|
return WorkflowError(
|
13
21
|
message=self.message,
|
14
22
|
code=self.code,
|
23
|
+
raw_data=self.raw_data,
|
15
24
|
)
|
16
25
|
|
17
26
|
@staticmethod
|
18
27
|
def of(workflow_error: WorkflowError) -> "NodeException":
|
19
|
-
return NodeException(message=workflow_error.message, code=workflow_error.code)
|
28
|
+
return NodeException(message=workflow_error.message, code=workflow_error.code, raw_data=workflow_error.raw_data)
|
20
29
|
|
21
30
|
|
22
31
|
class WorkflowInitializationException(Exception):
|
@@ -114,6 +114,7 @@ def test_inline_text_prompt_node__catch_provider_error(vellum_adhoc_prompt_clien
|
|
114
114
|
expected_error = VellumError(
|
115
115
|
message="OpenAI failed",
|
116
116
|
code="PROVIDER_ERROR",
|
117
|
+
raw_data={"type": "ERROR", "error": {"message": "Something went wrong"}},
|
117
118
|
)
|
118
119
|
|
119
120
|
def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
|
@@ -144,6 +145,7 @@ def test_inline_text_prompt_node__catch_provider_error(vellum_adhoc_prompt_clien
|
|
144
145
|
value=SdkVellumError(
|
145
146
|
message="OpenAI failed",
|
146
147
|
code=WorkflowErrorCode.PROVIDER_ERROR,
|
148
|
+
raw_data={"type": "ERROR", "error": {"message": "Something went wrong"}},
|
147
149
|
),
|
148
150
|
)
|
149
151
|
in outputs
|
@@ -4,8 +4,9 @@ from types import FrameType
|
|
4
4
|
from uuid import UUID
|
5
5
|
from typing import Annotated, Any, Dict, Literal, Optional, Union
|
6
6
|
|
7
|
-
from pydantic import BeforeValidator
|
7
|
+
from pydantic import BeforeValidator, SerializationInfo, model_serializer
|
8
8
|
|
9
|
+
from vellum import Vellum
|
9
10
|
from vellum.client.core.pydantic_utilities import UniversalBaseModel
|
10
11
|
from vellum.client.types.code_resource_definition import CodeResourceDefinition as ClientCodeResourceDefinition
|
11
12
|
from vellum.workflows.constants import AuthorizationType
|
@@ -78,6 +79,10 @@ class DeploymentDefinition(UniversalBaseModel):
|
|
78
79
|
deployment: str
|
79
80
|
release_tag: str = "LATEST"
|
80
81
|
|
82
|
+
# hydrated fields
|
83
|
+
name: Optional[str] = None
|
84
|
+
description: Optional[str] = None
|
85
|
+
|
81
86
|
def _is_uuid(self) -> bool:
|
82
87
|
"""Check if the deployment field is a valid UUID."""
|
83
88
|
try:
|
@@ -100,6 +105,32 @@ class DeploymentDefinition(UniversalBaseModel):
|
|
100
105
|
return self.deployment
|
101
106
|
return None
|
102
107
|
|
108
|
+
@model_serializer(mode="wrap")
|
109
|
+
def _serialize(self, handler, info: SerializationInfo):
|
110
|
+
"""Allow Pydantic to serialize directly given a `client` in context.
|
111
|
+
|
112
|
+
Falls back to the default serialization when client is not provided.
|
113
|
+
"""
|
114
|
+
context = info.context if info and hasattr(info, "context") else {}
|
115
|
+
client: Optional[Vellum] = context.get("client") if context else None
|
116
|
+
|
117
|
+
if client:
|
118
|
+
release = client.workflow_deployments.retrieve_workflow_deployment_release(
|
119
|
+
self.deployment, self.release_tag
|
120
|
+
)
|
121
|
+
self.name = release.deployment.name or self.deployment
|
122
|
+
self.description = release.description or f"Workflow Deployment for {self.deployment}"
|
123
|
+
|
124
|
+
return {
|
125
|
+
"type": "WORKFLOW_DEPLOYMENT",
|
126
|
+
"name": self.name,
|
127
|
+
"description": self.description,
|
128
|
+
"deployment": self.deployment,
|
129
|
+
"release_tag": self.release_tag,
|
130
|
+
}
|
131
|
+
|
132
|
+
return handler(self)
|
133
|
+
|
103
134
|
|
104
135
|
class ComposioToolDefinition(UniversalBaseModel):
|
105
136
|
"""Represents a specific Composio action that can be used in Tool Calling Node"""
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import pytest
|
2
2
|
from typing import List, Optional
|
3
3
|
|
4
|
-
from vellum import ChatMessage, SearchResult, VellumAudio, VellumDocument, VellumImage
|
4
|
+
from vellum import ChatMessage, SearchResult, VellumAudio, VellumDocument, VellumImage, VellumValue
|
5
5
|
from vellum.workflows.types.core import Json
|
6
6
|
from vellum.workflows.utils.vellum_variables import (
|
7
7
|
primitive_type_to_vellum_variable_type,
|
@@ -30,6 +30,12 @@ from vellum.workflows.utils.vellum_variables import (
|
|
30
30
|
(Optional[VellumAudio], "AUDIO"),
|
31
31
|
(VellumImage, "IMAGE"),
|
32
32
|
(Optional[VellumImage], "IMAGE"),
|
33
|
+
(list[ChatMessage], "CHAT_HISTORY"),
|
34
|
+
(Optional[list[ChatMessage]], "CHAT_HISTORY"),
|
35
|
+
(list[SearchResult], "SEARCH_RESULTS"),
|
36
|
+
(Optional[list[SearchResult]], "SEARCH_RESULTS"),
|
37
|
+
(list[VellumValue], "ARRAY"),
|
38
|
+
(Optional[list[VellumValue]], "ARRAY"),
|
33
39
|
],
|
34
40
|
)
|
35
41
|
def test_primitive_type_to_vellum_variable_type(type_, expected):
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import typing
|
2
|
-
from typing import List, Tuple, Type, Union, get_args, get_origin
|
2
|
+
from typing import Any, List, Tuple, Type, Union, get_args, get_origin
|
3
3
|
|
4
4
|
from vellum import (
|
5
5
|
ChatMessage,
|
@@ -86,6 +86,11 @@ def primitive_type_to_vellum_variable_type(type_: Union[Type, BaseDescriptor]) -
|
|
86
86
|
return "DOCUMENT"
|
87
87
|
elif _is_type_optionally_in(type_, (VellumError, VellumErrorRequest)):
|
88
88
|
return "ERROR"
|
89
|
+
|
90
|
+
builtin_list_type = _builtin_list_to_vellum_type(type_)
|
91
|
+
if builtin_list_type:
|
92
|
+
return builtin_list_type
|
93
|
+
|
89
94
|
elif _is_type_optionally_in(type_, (List[ChatMessage], List[ChatMessageRequest])):
|
90
95
|
return "CHAT_HISTORY"
|
91
96
|
elif _is_type_optionally_in(type_, (List[SearchResult], List[SearchResultRequest])):
|
@@ -128,7 +133,7 @@ def vellum_variable_type_to_openapi_type(vellum_type: VellumVariableType) -> str
|
|
128
133
|
return "object"
|
129
134
|
|
130
135
|
|
131
|
-
def _is_type_optionally_equal(type_: Type, target_type:
|
136
|
+
def _is_type_optionally_equal(type_: Type, target_type: Any) -> bool:
|
132
137
|
if type_ == target_type:
|
133
138
|
return True
|
134
139
|
|
@@ -147,7 +152,7 @@ def _is_type_optionally_equal(type_: Type, target_type: Type) -> bool:
|
|
147
152
|
return _is_type_optionally_equal(source_type, target_type)
|
148
153
|
|
149
154
|
|
150
|
-
def _is_type_optionally_in(type_: Type, target_types: Tuple[
|
155
|
+
def _is_type_optionally_in(type_: Type, target_types: Tuple[Any, ...]) -> bool:
|
151
156
|
return any(_is_type_optionally_equal(type_, target_type) for target_type in target_types)
|
152
157
|
|
153
158
|
|
@@ -181,3 +186,37 @@ def _is_subtype(source_type: Type, target_type: Type) -> bool:
|
|
181
186
|
return True
|
182
187
|
|
183
188
|
return False
|
189
|
+
|
190
|
+
|
191
|
+
def _unwrap_optional(type_: Type) -> Type:
|
192
|
+
origin = get_origin(type_)
|
193
|
+
if origin is typing.Union:
|
194
|
+
args = get_args(type_)
|
195
|
+
if len(args) == 2:
|
196
|
+
if args[1] is type(None):
|
197
|
+
return args[0]
|
198
|
+
if args[0] is type(None):
|
199
|
+
return args[1]
|
200
|
+
return type_
|
201
|
+
|
202
|
+
|
203
|
+
def _builtin_list_to_vellum_type(type_: Type) -> Union[str, None]:
|
204
|
+
candidate = _unwrap_optional(type_)
|
205
|
+
origin = get_origin(candidate)
|
206
|
+
if origin in (list, typing.List):
|
207
|
+
args = get_args(candidate)
|
208
|
+
if len(args) == 1:
|
209
|
+
item_type = args[0]
|
210
|
+
if _is_type_optionally_equal(item_type, ChatMessage) or _is_type_optionally_equal(
|
211
|
+
item_type, ChatMessageRequest
|
212
|
+
):
|
213
|
+
return "CHAT_HISTORY"
|
214
|
+
if _is_type_optionally_equal(item_type, SearchResult) or _is_type_optionally_equal(
|
215
|
+
item_type, SearchResultRequest
|
216
|
+
):
|
217
|
+
return "SEARCH_RESULTS"
|
218
|
+
if _is_type_optionally_equal(item_type, VellumValue) or _is_type_optionally_equal(
|
219
|
+
item_type, VellumValueRequest
|
220
|
+
):
|
221
|
+
return "ARRAY"
|
222
|
+
return None
|