vellum-ai 0.14.16__py3-none-any.whl → 0.14.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/__init__.py +2 -0
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/types/__init__.py +2 -0
- vellum/client/types/release.py +21 -0
- vellum/client/types/workflow_release_tag_read.py +7 -1
- vellum/prompts/blocks/compilation.py +14 -0
- vellum/types/release.py +3 -0
- vellum/workflows/events/workflow.py +15 -1
- vellum/workflows/nodes/bases/base.py +7 -7
- vellum/workflows/nodes/bases/base_adornment_node.py +2 -0
- vellum/workflows/nodes/core/retry_node/node.py +60 -40
- vellum/workflows/nodes/core/templating_node/node.py +2 -2
- vellum/workflows/nodes/core/try_node/node.py +1 -1
- vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +4 -0
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +27 -1
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/tests/test_inline_prompt_node.py +298 -0
- vellum/workflows/nodes/displayable/inline_prompt_node/node.py +24 -1
- vellum/workflows/nodes/experimental/openai_chat_completion_node/node.py +7 -1
- vellum/workflows/runner/runner.py +16 -1
- vellum/workflows/utils/tests/test_vellum_variables.py +7 -1
- vellum/workflows/utils/vellum_variables.py +4 -0
- {vellum_ai-0.14.16.dist-info → vellum_ai-0.14.18.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.16.dist-info → vellum_ai-0.14.18.dist-info}/RECORD +39 -34
- vellum_ee/workflows/display/nodes/base_node_display.py +35 -29
- vellum_ee/workflows/display/nodes/get_node_display_class.py +0 -9
- vellum_ee/workflows/display/nodes/vellum/base_adornment_node.py +38 -18
- vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +6 -0
- vellum_ee/workflows/display/nodes/vellum/templating_node.py +6 -7
- vellum_ee/workflows/display/nodes/vellum/tests/test_templating_node.py +97 -0
- vellum_ee/workflows/display/nodes/vellum/utils.py +1 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -1
- vellum_ee/workflows/display/vellum.py +1 -148
- vellum_ee/workflows/display/workflows/base_workflow_display.py +1 -1
- vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +61 -17
- vellum_ee/workflows/tests/test_display_meta.py +10 -10
- {vellum_ai-0.14.16.dist-info → vellum_ai-0.14.18.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.16.dist-info → vellum_ai-0.14.18.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.16.dist-info → vellum_ai-0.14.18.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,298 @@
|
|
1
|
+
import pytest
|
2
|
+
import json
|
3
|
+
from unittest import mock
|
4
|
+
from uuid import uuid4
|
5
|
+
from typing import Any, Iterator, List
|
6
|
+
|
7
|
+
from vellum import (
|
8
|
+
ChatMessagePromptBlock,
|
9
|
+
JinjaPromptBlock,
|
10
|
+
PlainTextPromptBlock,
|
11
|
+
PromptBlock,
|
12
|
+
PromptParameters,
|
13
|
+
RichTextPromptBlock,
|
14
|
+
VariablePromptBlock,
|
15
|
+
)
|
16
|
+
from vellum.client.types.execute_prompt_event import ExecutePromptEvent
|
17
|
+
from vellum.client.types.fulfilled_execute_prompt_event import FulfilledExecutePromptEvent
|
18
|
+
from vellum.client.types.initiated_execute_prompt_event import InitiatedExecutePromptEvent
|
19
|
+
from vellum.client.types.prompt_output import PromptOutput
|
20
|
+
from vellum.client.types.prompt_request_string_input import PromptRequestStringInput
|
21
|
+
from vellum.client.types.string_vellum_value import StringVellumValue
|
22
|
+
from vellum.workflows.errors import WorkflowErrorCode
|
23
|
+
from vellum.workflows.exceptions import NodeException
|
24
|
+
from vellum.workflows.inputs import BaseInputs
|
25
|
+
from vellum.workflows.nodes import InlinePromptNode
|
26
|
+
from vellum.workflows.nodes.displayable.bases.inline_prompt_node import BaseInlinePromptNode
|
27
|
+
from vellum.workflows.state import BaseState
|
28
|
+
from vellum.workflows.state.base import StateMeta
|
29
|
+
|
30
|
+
|
31
|
+
def test_validation_with_missing_variables():
|
32
|
+
"""Test that validation correctly identifies missing variables."""
|
33
|
+
test_blocks: List[PromptBlock] = [
|
34
|
+
VariablePromptBlock(input_variable="required_var1"),
|
35
|
+
VariablePromptBlock(input_variable="required_var2"),
|
36
|
+
RichTextPromptBlock(
|
37
|
+
blocks=[
|
38
|
+
PlainTextPromptBlock(text="Some text"),
|
39
|
+
VariablePromptBlock(input_variable="required_var3"),
|
40
|
+
],
|
41
|
+
),
|
42
|
+
JinjaPromptBlock(template="Template without variables"),
|
43
|
+
ChatMessagePromptBlock(
|
44
|
+
chat_role="USER",
|
45
|
+
blocks=[
|
46
|
+
RichTextPromptBlock(
|
47
|
+
blocks=[
|
48
|
+
PlainTextPromptBlock(text="Nested text"),
|
49
|
+
VariablePromptBlock(input_variable="required_var4"),
|
50
|
+
],
|
51
|
+
),
|
52
|
+
],
|
53
|
+
),
|
54
|
+
]
|
55
|
+
|
56
|
+
# GIVEN a BaseInlinePromptNode
|
57
|
+
class TestNode(BaseInlinePromptNode):
|
58
|
+
ml_model = "test-model"
|
59
|
+
blocks = test_blocks
|
60
|
+
prompt_inputs = {
|
61
|
+
"required_var1": "value1",
|
62
|
+
# required_var2 is missing
|
63
|
+
# required_var3 is missing
|
64
|
+
# required_var4 is missing
|
65
|
+
}
|
66
|
+
|
67
|
+
# WHEN the node is run
|
68
|
+
node = TestNode()
|
69
|
+
with pytest.raises(NodeException) as excinfo:
|
70
|
+
list(node.run())
|
71
|
+
|
72
|
+
# THEN the node raises the correct NodeException
|
73
|
+
assert excinfo.value.code == WorkflowErrorCode.INVALID_INPUTS
|
74
|
+
assert "required_var2" in str(excinfo.value)
|
75
|
+
assert "required_var3" in str(excinfo.value)
|
76
|
+
assert "required_var4" in str(excinfo.value)
|
77
|
+
|
78
|
+
|
79
|
+
def test_validation_with_all_variables_provided(vellum_adhoc_prompt_client):
|
80
|
+
"""Test that validation passes when all variables are provided."""
|
81
|
+
test_blocks: List[PromptBlock] = [
|
82
|
+
VariablePromptBlock(input_variable="required_var1"),
|
83
|
+
VariablePromptBlock(input_variable="required_var2"),
|
84
|
+
RichTextPromptBlock(
|
85
|
+
blocks=[
|
86
|
+
PlainTextPromptBlock(text="Some text"),
|
87
|
+
VariablePromptBlock(input_variable="required_var3"),
|
88
|
+
],
|
89
|
+
),
|
90
|
+
JinjaPromptBlock(template="Template without variables"),
|
91
|
+
ChatMessagePromptBlock(
|
92
|
+
chat_role="USER",
|
93
|
+
blocks=[
|
94
|
+
RichTextPromptBlock(
|
95
|
+
blocks=[
|
96
|
+
PlainTextPromptBlock(text="Nested text"),
|
97
|
+
VariablePromptBlock(input_variable="required_var4"),
|
98
|
+
],
|
99
|
+
),
|
100
|
+
],
|
101
|
+
),
|
102
|
+
]
|
103
|
+
|
104
|
+
# GIVEN a BaseInlinePromptNode
|
105
|
+
class TestNode(BaseInlinePromptNode):
|
106
|
+
ml_model = "test-model"
|
107
|
+
blocks = test_blocks
|
108
|
+
prompt_inputs = {
|
109
|
+
"required_var1": "value1",
|
110
|
+
"required_var2": "value2",
|
111
|
+
"required_var3": "value3",
|
112
|
+
"required_var4": "value4",
|
113
|
+
}
|
114
|
+
|
115
|
+
expected_outputs: List[PromptOutput] = [
|
116
|
+
StringVellumValue(value="Test response"),
|
117
|
+
]
|
118
|
+
|
119
|
+
def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
|
120
|
+
execution_id = str(uuid4())
|
121
|
+
events: List[ExecutePromptEvent] = [
|
122
|
+
InitiatedExecutePromptEvent(execution_id=execution_id),
|
123
|
+
FulfilledExecutePromptEvent(
|
124
|
+
execution_id=execution_id,
|
125
|
+
outputs=expected_outputs,
|
126
|
+
),
|
127
|
+
]
|
128
|
+
yield from events
|
129
|
+
|
130
|
+
vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.side_effect = generate_prompt_events
|
131
|
+
|
132
|
+
# WHEN the node is run
|
133
|
+
node = TestNode()
|
134
|
+
list(node.run())
|
135
|
+
|
136
|
+
# THEN the prompt is executed with the correct inputs
|
137
|
+
mock_api = vellum_adhoc_prompt_client.adhoc_execute_prompt_stream
|
138
|
+
assert mock_api.call_count == 1
|
139
|
+
assert mock_api.call_args.kwargs["input_values"] == [
|
140
|
+
PromptRequestStringInput(key="required_var1", type="STRING", value="value1"),
|
141
|
+
PromptRequestStringInput(key="required_var2", type="STRING", value="value2"),
|
142
|
+
PromptRequestStringInput(key="required_var3", type="STRING", value="value3"),
|
143
|
+
PromptRequestStringInput(key="required_var4", type="STRING", value="value4"),
|
144
|
+
]
|
145
|
+
|
146
|
+
|
147
|
+
def test_validation_with_extra_variables(vellum_adhoc_prompt_client):
|
148
|
+
"""Test that validation passes when extra variables are provided."""
|
149
|
+
test_blocks: List[PromptBlock] = [
|
150
|
+
VariablePromptBlock(input_variable="required_var"),
|
151
|
+
]
|
152
|
+
|
153
|
+
# GIVEN a BaseInlinePromptNode
|
154
|
+
class TestNode(BaseInlinePromptNode):
|
155
|
+
ml_model = "test-model"
|
156
|
+
blocks = test_blocks
|
157
|
+
prompt_inputs = {
|
158
|
+
"required_var": "value",
|
159
|
+
"extra_var": "extra_value", # This is not required
|
160
|
+
}
|
161
|
+
|
162
|
+
expected_outputs: List[PromptOutput] = [
|
163
|
+
StringVellumValue(value="Test response"),
|
164
|
+
]
|
165
|
+
|
166
|
+
def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
|
167
|
+
execution_id = str(uuid4())
|
168
|
+
events: List[ExecutePromptEvent] = [
|
169
|
+
InitiatedExecutePromptEvent(execution_id=execution_id),
|
170
|
+
FulfilledExecutePromptEvent(
|
171
|
+
execution_id=execution_id,
|
172
|
+
outputs=expected_outputs,
|
173
|
+
),
|
174
|
+
]
|
175
|
+
yield from events
|
176
|
+
|
177
|
+
vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.side_effect = generate_prompt_events
|
178
|
+
|
179
|
+
# WHEN the node is run
|
180
|
+
node = TestNode()
|
181
|
+
list(node.run())
|
182
|
+
|
183
|
+
# THEN the prompt is executed with the correct inputs
|
184
|
+
mock_api = vellum_adhoc_prompt_client.adhoc_execute_prompt_stream
|
185
|
+
assert mock_api.call_count == 1
|
186
|
+
assert mock_api.call_args.kwargs["input_values"] == [
|
187
|
+
PromptRequestStringInput(key="required_var", type="STRING", value="value"),
|
188
|
+
PromptRequestStringInput(key="extra_var", type="STRING", value="extra_value"),
|
189
|
+
]
|
190
|
+
|
191
|
+
|
192
|
+
def test_inline_prompt_node__json_output(vellum_adhoc_prompt_client):
|
193
|
+
"""Confirm that InlinePromptNodes output the expected JSON when run."""
|
194
|
+
|
195
|
+
# GIVEN a node that subclasses InlinePromptNode
|
196
|
+
class Inputs(BaseInputs):
|
197
|
+
input: str
|
198
|
+
|
199
|
+
class State(BaseState):
|
200
|
+
pass
|
201
|
+
|
202
|
+
class MyInlinePromptNode(InlinePromptNode):
|
203
|
+
ml_model = "gpt-4o"
|
204
|
+
blocks = []
|
205
|
+
parameters = PromptParameters(
|
206
|
+
stop=[],
|
207
|
+
temperature=0.0,
|
208
|
+
max_tokens=4096,
|
209
|
+
top_p=1.0,
|
210
|
+
top_k=0,
|
211
|
+
frequency_penalty=0.0,
|
212
|
+
presence_penalty=0.0,
|
213
|
+
logit_bias=None,
|
214
|
+
custom_parameters={
|
215
|
+
"json_mode": False,
|
216
|
+
"json_schema": {
|
217
|
+
"name": "get_result",
|
218
|
+
"schema": {
|
219
|
+
"type": "object",
|
220
|
+
"required": ["result"],
|
221
|
+
"properties": {"result": {"type": "string", "description": ""}},
|
222
|
+
},
|
223
|
+
},
|
224
|
+
},
|
225
|
+
)
|
226
|
+
|
227
|
+
# AND a known JSON response from invoking an inline prompt
|
228
|
+
expected_json = {"result": "Hello, world!"}
|
229
|
+
expected_outputs: List[PromptOutput] = [
|
230
|
+
StringVellumValue(value=json.dumps(expected_json)),
|
231
|
+
]
|
232
|
+
|
233
|
+
def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
|
234
|
+
execution_id = str(uuid4())
|
235
|
+
events: List[ExecutePromptEvent] = [
|
236
|
+
InitiatedExecutePromptEvent(execution_id=execution_id),
|
237
|
+
FulfilledExecutePromptEvent(
|
238
|
+
execution_id=execution_id,
|
239
|
+
outputs=expected_outputs,
|
240
|
+
),
|
241
|
+
]
|
242
|
+
yield from events
|
243
|
+
|
244
|
+
vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.side_effect = generate_prompt_events
|
245
|
+
|
246
|
+
# WHEN the node is run
|
247
|
+
node = MyInlinePromptNode(
|
248
|
+
state=State(
|
249
|
+
meta=StateMeta(workflow_inputs=Inputs(input="Generate JSON.")),
|
250
|
+
)
|
251
|
+
)
|
252
|
+
outputs = [o for o in node.run()]
|
253
|
+
|
254
|
+
# THEN the node should have produced the outputs we expect
|
255
|
+
results_output = outputs[0]
|
256
|
+
assert results_output.name == "results"
|
257
|
+
assert results_output.value == expected_outputs
|
258
|
+
|
259
|
+
text_output = outputs[1]
|
260
|
+
assert text_output.name == "text"
|
261
|
+
assert text_output.value == '{"result": "Hello, world!"}'
|
262
|
+
|
263
|
+
json_output = outputs[2]
|
264
|
+
assert json_output.name == "json"
|
265
|
+
assert json_output.value == expected_json
|
266
|
+
|
267
|
+
# AND we should have made the expected call to Vellum search
|
268
|
+
vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.assert_called_once_with(
|
269
|
+
blocks=[],
|
270
|
+
expand_meta=Ellipsis,
|
271
|
+
functions=None,
|
272
|
+
input_values=[],
|
273
|
+
input_variables=[],
|
274
|
+
ml_model="gpt-4o",
|
275
|
+
parameters=PromptParameters(
|
276
|
+
stop=[],
|
277
|
+
temperature=0.0,
|
278
|
+
max_tokens=4096,
|
279
|
+
top_p=1.0,
|
280
|
+
top_k=0,
|
281
|
+
frequency_penalty=0.0,
|
282
|
+
presence_penalty=0.0,
|
283
|
+
logit_bias=None,
|
284
|
+
custom_parameters={
|
285
|
+
"json_mode": False,
|
286
|
+
"json_schema": {
|
287
|
+
"name": "get_result",
|
288
|
+
"schema": {
|
289
|
+
"type": "object",
|
290
|
+
"required": ["result"],
|
291
|
+
"properties": {"result": {"type": "string", "description": ""}},
|
292
|
+
},
|
293
|
+
},
|
294
|
+
},
|
295
|
+
),
|
296
|
+
request_options=mock.ANY,
|
297
|
+
settings=None,
|
298
|
+
)
|
@@ -1,6 +1,7 @@
|
|
1
1
|
import json
|
2
|
-
from typing import Iterator
|
2
|
+
from typing import Any, Dict, Iterator, Type, Union
|
3
3
|
|
4
|
+
from vellum.workflows.constants import undefined
|
4
5
|
from vellum.workflows.errors import WorkflowErrorCode
|
5
6
|
from vellum.workflows.exceptions import NodeException
|
6
7
|
from vellum.workflows.nodes.displayable.bases import BaseInlinePromptNode as BaseInlinePromptNode
|
@@ -30,9 +31,11 @@ class InlinePromptNode(BaseInlinePromptNode[StateType]):
|
|
30
31
|
The outputs of the InlinePromptNode.
|
31
32
|
|
32
33
|
text: str - The result of the Prompt Execution
|
34
|
+
json: Optional[Dict[Any, Any]] - The result of the Prompt Execution in JSON format
|
33
35
|
"""
|
34
36
|
|
35
37
|
text: str
|
38
|
+
json: Union[Dict[Any, Any], Type[undefined]] = undefined
|
36
39
|
|
37
40
|
def run(self) -> Iterator[BaseOutput]:
|
38
41
|
outputs = yield from self._process_prompt_event_stream()
|
@@ -43,14 +46,31 @@ class InlinePromptNode(BaseInlinePromptNode[StateType]):
|
|
43
46
|
)
|
44
47
|
|
45
48
|
string_outputs = []
|
49
|
+
json_output = None
|
50
|
+
|
51
|
+
should_parse_json = False
|
52
|
+
if hasattr(self, "parameters"):
|
53
|
+
custom_params = self.parameters.custom_parameters
|
54
|
+
if custom_params and isinstance(custom_params, dict):
|
55
|
+
json_schema = custom_params.get("json_schema", {})
|
56
|
+
if (isinstance(json_schema, dict) and "schema" in json_schema) or custom_params.get("json_mode", {}):
|
57
|
+
should_parse_json = True
|
58
|
+
|
46
59
|
for output in outputs:
|
47
60
|
if output.value is None:
|
48
61
|
continue
|
49
62
|
|
50
63
|
if output.type == "STRING":
|
51
64
|
string_outputs.append(output.value)
|
65
|
+
if should_parse_json:
|
66
|
+
try:
|
67
|
+
parsed_json = json.loads(output.value)
|
68
|
+
json_output = parsed_json
|
69
|
+
except (json.JSONDecodeError, TypeError):
|
70
|
+
pass
|
52
71
|
elif output.type == "JSON":
|
53
72
|
string_outputs.append(json.dumps(output.value, indent=4))
|
73
|
+
json_output = output.value
|
54
74
|
elif output.type == "FUNCTION_CALL":
|
55
75
|
string_outputs.append(output.value.model_dump_json(indent=4))
|
56
76
|
else:
|
@@ -58,3 +78,6 @@ class InlinePromptNode(BaseInlinePromptNode[StateType]):
|
|
58
78
|
|
59
79
|
value = "\n".join(string_outputs)
|
60
80
|
yield BaseOutput(name="text", value=value)
|
81
|
+
|
82
|
+
if json_output:
|
83
|
+
yield BaseOutput(name="json", value=json_output)
|
@@ -27,6 +27,7 @@ from vellum import (
|
|
27
27
|
StringVellumValue,
|
28
28
|
VellumAudio,
|
29
29
|
VellumError,
|
30
|
+
VellumImage,
|
30
31
|
)
|
31
32
|
from vellum.prompts.blocks.compilation import compile_prompt_blocks
|
32
33
|
from vellum.prompts.blocks.types import CompiledChatMessagePromptBlock
|
@@ -202,7 +203,7 @@ class OpenAIChatCompletionNode(InlinePromptNode[StateType]):
|
|
202
203
|
json_content_item: ChatCompletionContentPartTextParam = {"type": "text", "text": json.dumps(json_value)}
|
203
204
|
content.append(json_content_item)
|
204
205
|
elif block.content.type == "IMAGE":
|
205
|
-
image_value = cast(
|
206
|
+
image_value = cast(VellumImage, block.content.value)
|
206
207
|
image_content_item: ChatCompletionContentPartImageParam = {
|
207
208
|
"type": "image_url",
|
208
209
|
"image_url": {"url": image_value.src},
|
@@ -251,6 +252,11 @@ class OpenAIChatCompletionNode(InlinePromptNode[StateType]):
|
|
251
252
|
}
|
252
253
|
|
253
254
|
content.append(audio_content_item)
|
255
|
+
elif block.content.type == "DOCUMENT":
|
256
|
+
raise NodeException(
|
257
|
+
code=WorkflowErrorCode.PROVIDER_ERROR,
|
258
|
+
message="Document chat message content type is not currently supported",
|
259
|
+
)
|
254
260
|
else:
|
255
261
|
raise NodeException(
|
256
262
|
code=WorkflowErrorCode.INTERNAL_ERROR,
|
@@ -43,7 +43,7 @@ from vellum.workflows.events.workflow import (
|
|
43
43
|
WorkflowExecutionSnapshottedEvent,
|
44
44
|
WorkflowExecutionStreamingBody,
|
45
45
|
)
|
46
|
-
from vellum.workflows.exceptions import NodeException
|
46
|
+
from vellum.workflows.exceptions import NodeException, WorkflowInitializationException
|
47
47
|
from vellum.workflows.nodes.bases import BaseNode
|
48
48
|
from vellum.workflows.nodes.bases.base import NodeRunResponse
|
49
49
|
from vellum.workflows.nodes.mocks import MockNodeExecutionArg
|
@@ -332,6 +332,18 @@ class WorkflowRunner(Generic[StateType]):
|
|
332
332
|
parent=parent_context,
|
333
333
|
)
|
334
334
|
)
|
335
|
+
except WorkflowInitializationException as e:
|
336
|
+
self._workflow_event_inner_queue.put(
|
337
|
+
NodeExecutionRejectedEvent(
|
338
|
+
trace_id=node.state.meta.trace_id,
|
339
|
+
span_id=span_id,
|
340
|
+
body=NodeExecutionRejectedBody(
|
341
|
+
node_definition=node.__class__,
|
342
|
+
error=e.error,
|
343
|
+
),
|
344
|
+
parent=parent_context,
|
345
|
+
)
|
346
|
+
)
|
335
347
|
except Exception as e:
|
336
348
|
logger.exception(f"An unexpected error occurred while running node {node.__class__.__name__}")
|
337
349
|
|
@@ -563,6 +575,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
563
575
|
except NodeException as e:
|
564
576
|
self._workflow_event_outer_queue.put(self._reject_workflow_event(e.error))
|
565
577
|
return
|
578
|
+
except WorkflowInitializationException as e:
|
579
|
+
self._workflow_event_outer_queue.put(self._reject_workflow_event(e.error))
|
580
|
+
return
|
566
581
|
except Exception:
|
567
582
|
err_message = f"An unexpected error occurred while initializing node {node_cls.__name__}"
|
568
583
|
logger.exception(err_message)
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import pytest
|
2
2
|
from typing import List, Optional
|
3
3
|
|
4
|
-
from vellum import ChatMessage, SearchResult
|
4
|
+
from vellum import ChatMessage, SearchResult, VellumAudio, VellumDocument, VellumImage
|
5
5
|
from vellum.workflows.types.core import Json
|
6
6
|
from vellum.workflows.utils.vellum_variables import primitive_type_to_vellum_variable_type
|
7
7
|
|
@@ -21,6 +21,12 @@ from vellum.workflows.utils.vellum_variables import primitive_type_to_vellum_var
|
|
21
21
|
(Optional[List[SearchResult]], "SEARCH_RESULTS"),
|
22
22
|
(Json, "JSON"),
|
23
23
|
(Optional[Json], "JSON"),
|
24
|
+
(VellumDocument, "DOCUMENT"),
|
25
|
+
(Optional[VellumDocument], "DOCUMENT"),
|
26
|
+
(VellumAudio, "AUDIO"),
|
27
|
+
(Optional[VellumAudio], "AUDIO"),
|
28
|
+
(VellumImage, "IMAGE"),
|
29
|
+
(Optional[VellumImage], "IMAGE"),
|
24
30
|
],
|
25
31
|
)
|
26
32
|
def test_primitive_type_to_vellum_variable_type(type_, expected):
|
@@ -10,6 +10,8 @@ from vellum import (
|
|
10
10
|
SearchResultRequest,
|
11
11
|
VellumAudio,
|
12
12
|
VellumAudioRequest,
|
13
|
+
VellumDocument,
|
14
|
+
VellumDocumentRequest,
|
13
15
|
VellumError,
|
14
16
|
VellumErrorRequest,
|
15
17
|
VellumImage,
|
@@ -62,6 +64,8 @@ def primitive_type_to_vellum_variable_type(type_: Union[Type, BaseDescriptor]) -
|
|
62
64
|
return "IMAGE"
|
63
65
|
elif _is_type_optionally_in(type_, (VellumAudio, VellumAudioRequest)):
|
64
66
|
return "AUDIO"
|
67
|
+
elif _is_type_optionally_in(type_, (VellumDocument, VellumDocumentRequest)):
|
68
|
+
return "DOCUMENT"
|
65
69
|
elif _is_type_optionally_in(type_, (VellumError, VellumErrorRequest)):
|
66
70
|
return "ERROR"
|
67
71
|
elif _is_type_optionally_in(type_, (List[ChatMessage], List[ChatMessageRequest])):
|