vellum-ai 0.14.65__py3-none-any.whl → 0.14.67__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/README.md +1 -1
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/reference.md +2767 -0
- vellum/client/types/document_read.py +0 -1
- vellum/client/types/folder_entity_prompt_sandbox_data.py +1 -0
- vellum/client/types/folder_entity_workflow_sandbox_data.py +1 -0
- vellum/workflows/expressions/accessor.py +22 -5
- vellum/workflows/expressions/tests/test_accessor.py +189 -0
- vellum/workflows/nodes/bases/base.py +30 -39
- vellum/workflows/nodes/bases/tests/test_base_node.py +48 -2
- vellum/workflows/nodes/displayable/api_node/node.py +3 -1
- vellum/workflows/nodes/displayable/api_node/tests/test_api_node.py +32 -0
- vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +28 -0
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/tests/test_inline_prompt_node.py +26 -23
- vellum/workflows/nodes/displayable/conditional_node/node.py +1 -2
- vellum/workflows/nodes/displayable/final_output_node/node.py +2 -0
- vellum/workflows/nodes/displayable/inline_prompt_node/node.py +4 -14
- vellum/workflows/nodes/displayable/search_node/node.py +8 -0
- vellum/workflows/nodes/displayable/search_node/tests/test_node.py +19 -0
- vellum/workflows/nodes/experimental/tool_calling_node/utils.py +4 -13
- vellum/workflows/runner/runner.py +13 -17
- vellum/workflows/state/base.py +0 -4
- {vellum_ai-0.14.65.dist-info → vellum_ai-0.14.67.dist-info}/METADATA +2 -2
- {vellum_ai-0.14.65.dist-info → vellum_ai-0.14.67.dist-info}/RECORD +33 -30
- vellum_cli/image_push.py +62 -7
- vellum_cli/pull.py +38 -9
- vellum_cli/tests/test_image_push_error_handling.py +184 -0
- vellum_cli/tests/test_pull.py +12 -9
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_inline_workflow_serialization.py +661 -0
- vellum_ee/workflows/display/utils/expressions.py +17 -0
- {vellum_ai-0.14.65.dist-info → vellum_ai-0.14.67.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.65.dist-info → vellum_ai-0.14.67.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.65.dist-info → vellum_ai-0.14.67.dist-info}/entry_points.txt +0 -0
vellum/client/types/document_read.py
CHANGED
@@ -32,7 +32,6 @@ class DocumentRead(UniversalBaseModel):
     """

     original_file_url: typing.Optional[str] = None
-    processed_file_url: typing.Optional[str] = None
     document_to_document_indexes: typing.List[DocumentDocumentToDocumentIndex]
     metadata: typing.Optional[typing.Dict[str, typing.Optional[typing.Any]]] = pydantic.Field(default=None)
     """
vellum/workflows/expressions/accessor.py
CHANGED
@@ -35,20 +35,37 @@ class AccessorExpression(BaseDescriptor[Any]):
             if isinstance(self._field, int):
                 raise InvalidExpressionException("Cannot access field by index on a dataclass")

-            return getattr(base, self._field)
+            try:
+                return getattr(base, self._field)
+            except AttributeError:
+                raise InvalidExpressionException(f"Field '{self._field}' not found on dataclass {type(base).__name__}")

         if isinstance(base, BaseModel):
             if isinstance(self._field, int):
                 raise InvalidExpressionException("Cannot access field by index on a BaseModel")

-            return getattr(base, self._field)
+            try:
+                return getattr(base, self._field)
+            except AttributeError:
+                raise InvalidExpressionException(f"Field '{self._field}' not found on BaseModel {type(base).__name__}")

         if isinstance(base, Mapping):
-            return base[self._field]
+            try:
+                return base[self._field]
+            except KeyError:
+                raise InvalidExpressionException(f"Key '{self._field}' not found in mapping")

         if isinstance(base, Sequence):
-            index = int(self._field)
-            return base[index]
+            try:
+                index = int(self._field)
+                return base[index]
+            except (IndexError, ValueError):
+                if isinstance(self._field, int) or (isinstance(self._field, str) and self._field.lstrip("-").isdigit()):
+                    raise InvalidExpressionException(
+                        f"Index {self._field} is out of bounds for sequence of length {len(base)}"
+                    )
+                else:
+                    raise InvalidExpressionException(f"Invalid index '{self._field}' for sequence access")

         raise InvalidExpressionException(f"Cannot get field {self._field} from {base}")

vellum/workflows/expressions/tests/test_accessor.py
ADDED
@@ -0,0 +1,189 @@
+import pytest
+from dataclasses import dataclass
+
+from pydantic import BaseModel
+
+from vellum.workflows.descriptors.exceptions import InvalidExpressionException
+from vellum.workflows.expressions.accessor import AccessorExpression
+from vellum.workflows.references.constant import ConstantValueReference
+from vellum.workflows.state.base import BaseState
+
+
+@dataclass
+class TestDataclass:
+    name: str
+    value: int
+
+
+class TestBaseModel(BaseModel):
+    name: str
+    value: int
+
+
+class TestState(BaseState):
+    pass
+
+
+def test_accessor_expression_dict_valid_key():
+    state = TestState()
+    base_ref = ConstantValueReference({"name": "test", "value": 42})
+    accessor = AccessorExpression(base=base_ref, field="name")
+
+    result = accessor.resolve(state)
+
+    assert result == "test"
+
+
+def test_accessor_expression_dict_invalid_key():
+    state = TestState()
+    base_ref = ConstantValueReference({"name": "test", "value": 42})
+    accessor = AccessorExpression(base=base_ref, field="missing_key")
+
+    with pytest.raises(InvalidExpressionException) as exc_info:
+        accessor.resolve(state)
+
+    assert "Key 'missing_key' not found in mapping" in str(exc_info.value)
+
+
+def test_accessor_expression_list_valid_index():
+    state = TestState()
+    base_ref = ConstantValueReference(["first", "second", "third"])
+    accessor = AccessorExpression(base=base_ref, field=1)
+
+    result = accessor.resolve(state)
+
+    assert result == "second"
+
+
+def test_accessor_expression_list_invalid_index():
+    state = TestState()
+    base_ref = ConstantValueReference(["first", "second"])
+    accessor = AccessorExpression(base=base_ref, field=5)
+
+    with pytest.raises(InvalidExpressionException) as exc_info:
+        accessor.resolve(state)
+
+    assert "Index 5 is out of bounds for sequence of length 2" in str(exc_info.value)
+
+
+def test_accessor_expression_list_negative_index():
+    state = TestState()
+    base_ref = ConstantValueReference(["first", "second", "third"])
+    accessor = AccessorExpression(base=base_ref, field=-1)
+
+    result = accessor.resolve(state)
+
+    assert result == "third"
+
+
+def test_accessor_expression_list_invalid_negative_index():
+    state = TestState()
+    base_ref = ConstantValueReference(["first", "second"])
+    accessor = AccessorExpression(base=base_ref, field=-5)
+
+    with pytest.raises(InvalidExpressionException) as exc_info:
+        accessor.resolve(state)
+
+    assert "Index -5 is out of bounds for sequence of length 2" in str(exc_info.value)
+
+
+def test_accessor_expression_list_string_index():
+    state = TestState()
+    base_ref = ConstantValueReference(["first", "second", "third"])
+    accessor = AccessorExpression(base=base_ref, field="1")
+
+    result = accessor.resolve(state)
+
+    assert result == "second"
+
+
+def test_accessor_expression_list_invalid_string_index():
+    state = TestState()
+    base_ref = ConstantValueReference(["first", "second"])
+    accessor = AccessorExpression(base=base_ref, field="invalid")
+
+    with pytest.raises(InvalidExpressionException) as exc_info:
+        accessor.resolve(state)
+
+    assert "Invalid index 'invalid' for sequence access" in str(exc_info.value)
+
+
+def test_accessor_expression_dataclass_valid_field():
+    state = TestState()
+    test_obj = TestDataclass(name="test", value=42)
+    base_ref = ConstantValueReference(test_obj)
+    accessor = AccessorExpression(base=base_ref, field="name")
+
+    result = accessor.resolve(state)
+
+    assert result == "test"
+
+
+def test_accessor_expression_dataclass_invalid_field():
+    state = TestState()
+    test_obj = TestDataclass(name="test", value=42)
+    base_ref = ConstantValueReference(test_obj)
+    accessor = AccessorExpression(base=base_ref, field="missing_field")
+
+    with pytest.raises(InvalidExpressionException) as exc_info:
+        accessor.resolve(state)
+
+    assert "Field 'missing_field' not found on dataclass TestDataclass" in str(exc_info.value)
+
+
+def test_accessor_expression_basemodel_valid_field():
+    state = TestState()
+    test_obj = TestBaseModel(name="test", value=42)
+    base_ref = ConstantValueReference(test_obj)
+    accessor = AccessorExpression(base=base_ref, field="name")
+
+    result = accessor.resolve(state)
+
+    assert result == "test"
+
+
+def test_accessor_expression_basemodel_invalid_field():
+    state = TestState()
+    test_obj = TestBaseModel(name="test", value=42)
+    base_ref = ConstantValueReference(test_obj)
+    accessor = AccessorExpression(base=base_ref, field="missing_field")
+
+    with pytest.raises(InvalidExpressionException) as exc_info:
+        accessor.resolve(state)
+
+    assert "Field 'missing_field' not found on BaseModel TestBaseModel" in str(exc_info.value)
+
+
+def test_accessor_expression_dataclass_index_access():
+    state = TestState()
+    test_obj = TestDataclass(name="test", value=42)
+    base_ref = ConstantValueReference(test_obj)
+    accessor = AccessorExpression(base=base_ref, field=0)
+
+    with pytest.raises(InvalidExpressionException) as exc_info:
+        accessor.resolve(state)
+
+    assert "Cannot access field by index on a dataclass" in str(exc_info.value)
+
+
+def test_accessor_expression_basemodel_index_access():
+    state = TestState()
+    test_obj = TestBaseModel(name="test", value=42)
+    base_ref = ConstantValueReference(test_obj)
+    accessor = AccessorExpression(base=base_ref, field=0)
+
+    with pytest.raises(InvalidExpressionException) as exc_info:
+        accessor.resolve(state)
+
+    assert "Cannot access field by index on a BaseModel" in str(exc_info.value)
+
+
+def test_accessor_expression_unsupported_type():
+    state = TestState()
+    base_ref = ConstantValueReference(42)
+    accessor = AccessorExpression(base=base_ref, field="field")
+
+    with pytest.raises(InvalidExpressionException) as exc_info:
+        accessor.resolve(state)
+
+    assert "Cannot get field field from 42" in str(exc_info.value)
vellum/workflows/nodes/bases/base.py
CHANGED
@@ -10,6 +10,7 @@ from vellum.workflows.constants import undefined
 from vellum.workflows.descriptors.base import BaseDescriptor
 from vellum.workflows.descriptors.utils import is_unresolved, resolve_value
 from vellum.workflows.errors.types import WorkflowErrorCode
+from vellum.workflows.events.node import NodeExecutionStreamingEvent
 from vellum.workflows.exceptions import NodeException
 from vellum.workflows.graph import Graph
 from vellum.workflows.graph.graph import GraphTarget
@@ -92,7 +93,7 @@ class BaseNodeMeta(ABCMeta):
             if issubclass(base, BaseNode):
                 ports_dct = {p.name: Port(default=p.default) for p in base.Ports}
                 ports_dct["__module__"] = dct["__module__"]
-                dct["Ports"] = type(f"{name}.Ports", (
+                dct["Ports"] = type(f"{name}.Ports", (base.Ports,), ports_dct)
                 break

         if "Execution" not in dct:
@@ -357,41 +358,15 @@ class BaseNode(Generic[StateType], ABC, metaclass=BaseNodeMeta):
                 state.meta.node_execution_cache._dependencies_invoked[execution_id].add(invoked_by)
                 return execution_id

-        # For AWAIT_ANY in workflows, we
-        #
-        # 2. The node is being re-executed because of a loop
-        # For case 1, we need to remove the fork id from the node_to_fork_id mapping
-        # For case 2, we need to check if the node is in a loop
+        # For AWAIT_ANY in workflows, we need to detect if the node is in a loop
+        # If the node is in a loop, we can trigger the node again
         in_loop = False
-        # Default to true, will be set to false if the merged node has already been triggered
-        should_retrigger = True
         if cls.merge_behavior == MergeBehavior.AWAIT_ANY:
-            # Get the
-
-
-            #
-            if
-            fork_id = state.meta.node_execution_cache.__node_to_fork_id__.get(invoked_by_node, None)
-            if fork_id:
-                # If the invoked by node has a fork id and that fork id is in __all_fork_ids__
-                # We will
-                # 1. remove the fork id from __all_fork_ids__
-                # 2. remove the fork id from the __node_to_fork_id__ mapping
-                # else (this mean the fork has already been triggered)
-                # remove the id from the node_to_fork_id mapping and not retrigger again
-                all_fork_ids = state.meta.node_execution_cache.__all_fork_ids__
-                if fork_id in all_fork_ids:
-                    # When the next forked node merge, it will not trigger the node again
-                    # We should consider adding a lock here to prevent race condition
-                    all_fork_ids.remove(fork_id)
-                    state.meta.node_execution_cache.__node_to_fork_id__.pop(invoked_by_node, None)
-                else:
-                    should_retrigger = False
-                    state.meta.node_execution_cache.__node_to_fork_id__.pop(invoked_by_node, None)
-
-            # If should_retrigger is false, then we will not trigger the node already
-            # So we don't need to check loop behavior
-            if should_retrigger:
+            # Get the latest fulfilled execution ID of current node
+            fulfilled_stack = state.meta.node_execution_cache._node_executions_fulfilled[cls.node_class]
+            current_latest_fulfilled_id = fulfilled_stack.peek() if not fulfilled_stack.is_empty() else None
+            # If the current node has not been fulfilled yet, we don't need to check for loop
+            if current_latest_fulfilled_id is not None:
                 # Trace back through the dependency chain to detect if this node triggers itself
                 visited = set()
                 current_execution_id = invoked_by
@@ -417,15 +392,18 @@ class BaseNode(Generic[StateType], ABC, metaclass=BaseNodeMeta):
                     current_execution_id
                 )

-                # If we've found our target node class in the chain
+                # If we've found our target node class in the chain
                 if current_node_class == cls.node_class:
-                    in_loop = True
+                    # Check if the execution id is the same as the latest fulfilled execution id
+                    # If yes, we're in a loop
+                    if current_execution_id == current_latest_fulfilled_id:
+                        in_loop = True
+                    # If not, current node has been triggered by other node,
+                    # we can break out of the loop
                     break

         for queued_node_execution_id in state.meta.node_execution_cache._node_executions_queued[cls.node_class]:
-
-            # So we don't need to trigger the node again
-            if not should_retrigger or (
+            if (
                 invoked_by not in state.meta.node_execution_cache._dependencies_invoked[queued_node_execution_id]
                 and not in_loop
             ):
@@ -509,3 +487,16 @@ class BaseNode(Generic[StateType], ABC, metaclass=BaseNodeMeta):

     def __repr__(self) -> str:
         return str(self.__class__)
+
+    __simulates_workflow_output__ = False
+
+    def __directly_emit_workflow_output__(
+        self, event: NodeExecutionStreamingEvent, workflow_output_descriptor: OutputReference
+    ) -> bool:
+        """
+        In the legacy workflow runner, there was support for emitting streaming workflow outputs for prompt nodes
+        connected to terminal nodes. These two private methods provides a hacky, intentionally short-lived workaround
+        for us to enable this until we can directly reference prompt outputs from the UI.
+        """
+
+        return False
vellum/workflows/nodes/bases/tests/test_base_node.py
CHANGED
@@ -1,6 +1,6 @@
 import pytest
 from uuid import UUID
-from typing import Optional
+from typing import Optional, Set

 from vellum.client.types.string_vellum_value_request import StringVellumValueRequest
 from vellum.core.pydantic_utilities import UniversalBaseModel
@@ -9,7 +9,9 @@ from vellum.workflows.descriptors.tests.test_utils import FixtureState
 from vellum.workflows.inputs.base import BaseInputs
 from vellum.workflows.nodes import FinalOutputNode
 from vellum.workflows.nodes.bases.base import BaseNode
-from vellum.workflows.outputs.base import BaseOutputs
+from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
+from vellum.workflows.ports.port import Port
+from vellum.workflows.references.constant import ConstantValueReference
 from vellum.workflows.references.node import NodeReference
 from vellum.workflows.references.output import OutputReference
 from vellum.workflows.state.base import BaseState, StateMeta
@@ -333,3 +335,47 @@ def test_base_node__node_reference_of_inherited_annotation():
     # THEN the node reference is of the correct type
     assert isinstance(node_reference, NodeReference)
     assert node_reference.name == "foo"
+
+
+def test_base_node__ports_inheritance():
+    # GIVEN a node with one port
+    class MyNode(BaseNode):
+        class Ports(BaseNode.Ports):
+            foo = Port.on_if(ConstantValueReference(True))
+
+            def __lt__(self, output: BaseOutput) -> Set[Port]:
+                return {self.foo}
+
+    # AND a node that inherits from MyNode
+    class InheritedNode(MyNode):
+        pass
+
+    # WHEN we collect the ports
+    ports = [port.name for port in InheritedNode.Ports]
+
+    # THEN the inheritance is correct
+    inherited_ports = InheritedNode.Ports()
+    assert inherited_ports.__lt__(BaseOutput(name="foo")) == {InheritedNode.Ports.foo}
+
+    # AND the ports are correct
+    assert ports == ["foo"]
+
+
+def test_base_node__ports_inheritance__cumulative_ports():
+    # GIVEN a node with one port
+    class MyNode(BaseNode):
+        class Ports(BaseNode.Ports):
+            foo = Port.on_if(ConstantValueReference(True))
+
+    # AND a node that inherits from MyNode with another port
+    class InheritedNode(MyNode):
+        class Ports(MyNode.Ports):
+            bar = Port.on_if(ConstantValueReference(True))
+
+    # WHEN we collect the ports
+    ports = [port.name for port in InheritedNode.Ports]
+
+    # THEN the ports are correct
+    # Potentially in the future, we support inheriting ports from multiple parents.
+    # For now, we take only the declared ports, so that not all nodes have the default port.
+    assert ports == ["bar"]
vellum/workflows/nodes/displayable/api_node/node.py
CHANGED
@@ -51,7 +51,9 @@ class APINode(BaseAPINode):
         final_headers = {**headers, **header_overrides}

         vellum_client_wrapper = self._context.vellum_client._client_wrapper
-        if self.url.startswith(vellum_client_wrapper._environment.default) and
+        if self.url.startswith(vellum_client_wrapper._environment.default) and (
+            "X-API-Key" not in final_headers and "X_API_KEY" not in final_headers
+        ):
             final_headers["X-API-Key"] = vellum_client_wrapper.api_key

         return self._run(
vellum/workflows/nodes/displayable/api_node/tests/test_api_node.py
CHANGED
@@ -158,3 +158,35 @@ def test_api_node__detects_client_environment_urls__does_not_override_headers(
     # AND the vellum API should have been called with the correct headers
     assert mock_response.last_request
     assert mock_response.last_request.headers["X-API-Key"] == "vellum-api-key-5678"
+
+
+def test_api_node__detects_client_environment_urls__legacy_does_not_override_headers(
+    mock_httpx_transport, mock_requests, monkeypatch
+):
+    # GIVEN an API node with a URL pointing back to Vellum
+    class SimpleAPINodeToVellum(APINode):
+        url = "https://api.vellum.ai"
+        headers = {
+            "X_API_KEY": "vellum-api-key-5678",
+        }
+
+    # AND a mock request sent to the Vellum API would return a 200
+    mock_response = mock_requests.get(
+        "https://api.vellum.ai",
+        status_code=200,
+        json={"data": [1, 2, 3]},
+    )
+
+    # AND an api key is set
+    monkeypatch.setenv("VELLUM_API_KEY", "vellum-api-key-1234")
+
+    # WHEN we run the node
+    node = SimpleAPINodeToVellum()
+    node.run()
+
+    # THEN the execute_api method should not have been called
+    mock_httpx_transport.handle_request.assert_not_called()
+
+    # AND the vellum API should have been called with the correct headers
+    assert mock_response.last_request
+    assert mock_response.last_request.headers["X_API_KEY"] == "vellum-api-key-5678"
vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py
CHANGED
@@ -5,9 +5,11 @@ from vellum import AdHocExecutePromptEvent, ExecutePromptEvent, PromptOutput
 from vellum.client.core.api_error import ApiError
 from vellum.core import RequestOptions
 from vellum.workflows.errors.types import WorkflowErrorCode, vellum_error_to_workflow_error
+from vellum.workflows.events.node import NodeExecutionStreamingEvent
 from vellum.workflows.exceptions import NodeException
 from vellum.workflows.nodes.bases import BaseNode
 from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
+from vellum.workflows.references.output import OutputReference
 from vellum.workflows.types.core import EntityInputsInterface, MergeBehavior
 from vellum.workflows.types.generics import StateType

@@ -85,3 +87,29 @@ class BasePromptNode(BaseNode, Generic[StateType]):
                 message="Failed to execute Prompt",
                 code=WorkflowErrorCode.INTERNAL_ERROR,
             ) from e
+
+    def __directly_emit_workflow_output__(
+        self,
+        event: NodeExecutionStreamingEvent,
+        workflow_output_descriptor: OutputReference,
+    ) -> bool:
+        if event.output.name != "results":
+            return False
+
+        if not isinstance(event.output.delta, str) and not event.output.is_initiated:
+            return False
+
+        target_nodes = [e.to_node for e in self.Ports.default.edges if e.to_node.__simulates_workflow_output__]
+        target_node_output = next(
+            (
+                o
+                for target_node in target_nodes
+                for o in target_node.Outputs
+                if o == workflow_output_descriptor.instance
+            ),
+            None,
+        )
+        if not target_node_output:
+            return False
+
+        return True
vellum/workflows/nodes/displayable/bases/inline_prompt_node/tests/test_inline_prompt_node.py
CHANGED
@@ -192,7 +192,30 @@ def test_validation_with_extra_variables(vellum_adhoc_prompt_client):
     ]


-def test_inline_prompt_node__json_output(vellum_adhoc_prompt_client):
+@pytest.mark.parametrize(
+    "custom_parameters,test_description",
+    [
+        (
+            {
+                "json_mode": False,
+                "json_schema": {
+                    "name": "get_result",
+                    "schema": {
+                        "type": "object",
+                        "required": ["result"],
+                        "properties": {"result": {"type": "string", "description": ""}},
+                    },
+                },
+            },
+            "with json_schema configured",
+        ),
+        (
+            {},
+            "without json_mode or json_schema configured",
+        ),
+    ],
+)
+def test_inline_prompt_node__json_output(vellum_adhoc_prompt_client, custom_parameters, test_description):
     """Confirm that InlinePromptNodes output the expected JSON when run."""

     # GIVEN a node that subclasses InlinePromptNode
@@ -214,17 +237,7 @@ def test_inline_prompt_node__json_output(vellum_adhoc_prompt_client):
         frequency_penalty=0.0,
         presence_penalty=0.0,
         logit_bias=None,
-        custom_parameters={
-            "json_mode": False,
-            "json_schema": {
-                "name": "get_result",
-                "schema": {
-                    "type": "object",
-                    "required": ["result"],
-                    "properties": {"result": {"type": "string", "description": ""}},
-                },
-            },
-        },
+        custom_parameters=custom_parameters,
     )

     # AND a known JSON response from invoking an inline prompt
@@ -284,17 +297,7 @@ def test_inline_prompt_node__json_output(vellum_adhoc_prompt_client):
            frequency_penalty=0.0,
            presence_penalty=0.0,
            logit_bias=None,
-            custom_parameters={
-                "json_mode": False,
-                "json_schema": {
-                    "name": "get_result",
-                    "schema": {
-                        "type": "object",
-                        "required": ["result"],
-                        "properties": {"result": {"type": "string", "description": ""}},
-                    },
-                },
-            },
+            custom_parameters=custom_parameters,
         ),
         request_options=mock.ANY,
         settings=None,
vellum/workflows/nodes/displayable/conditional_node/node.py
CHANGED
@@ -2,7 +2,6 @@ from typing import Set

 from vellum.workflows.nodes.bases import BaseNode
 from vellum.workflows.outputs.base import BaseOutputs
-from vellum.workflows.ports.node_ports import NodePorts
 from vellum.workflows.ports.port import Port
 from vellum.workflows.ports.utils import validate_ports
 from vellum.workflows.state.base import BaseState
@@ -18,7 +17,7 @@ class ConditionalNode(BaseNode):
     class Trigger(BaseNode.Trigger):
         merge_behavior = MergeBehavior.AWAIT_ANY

-    class Ports(NodePorts):
+    class Ports(BaseNode.Ports):
        def __call__(self, outputs: BaseOutputs, state: BaseState) -> Set[Port]:
            all_ports = [port for port in self.__class__]
            enforce_single_invoked_port = validate_ports(all_ports)
vellum/workflows/nodes/displayable/inline_prompt_node/node.py
CHANGED
@@ -48,26 +48,16 @@ class InlinePromptNode(BaseInlinePromptNode[StateType]):
         string_outputs = []
         json_output = None

-        should_parse_json = False
-        if hasattr(self, "parameters"):
-            custom_params = self.parameters.custom_parameters
-            if custom_params and isinstance(custom_params, dict):
-                json_schema = custom_params.get("json_schema", {})
-                if (isinstance(json_schema, dict) and "schema" in json_schema) or custom_params.get("json_mode", {}):
-                    should_parse_json = True
-
         for output in outputs:
             if output.value is None:
                 continue

             if output.type == "STRING":
                 string_outputs.append(output.value)
-
-
-
-
-                    except (json.JSONDecodeError, TypeError):
-                        pass
+                try:
+                    json_output = json.loads(output.value)
+                except (json.JSONDecodeError, TypeError):
+                    pass
             elif output.type == "JSON":
                 string_outputs.append(json.dumps(output.value, indent=4))
                 json_output = output.value
vellum/workflows/nodes/displayable/search_node/node.py
CHANGED
@@ -1,6 +1,8 @@
 import json
 from typing import ClassVar

+from vellum.workflows.errors import WorkflowErrorCode
+from vellum.workflows.exceptions import NodeException
 from vellum.workflows.nodes.displayable.bases import BaseSearchNode as BaseSearchNode
 from vellum.workflows.state.encoder import DefaultStateEncoder
 from vellum.workflows.types import MergeBehavior
@@ -35,6 +37,12 @@ class SearchNode(BaseSearchNode[StateType]):
         text: str

     def run(self) -> Outputs:
+        if self.query is None or self.query == "":
+            raise NodeException(
+                message="Search query is required but was not provided",
+                code=WorkflowErrorCode.INVALID_INPUTS,
+            )
+
         if not isinstance(self.query, str):
             self.query = json.dumps(self.query, cls=DefaultStateEncoder)
