vellum-ai 0.14.65__py3-none-any.whl → 0.14.67__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. vellum/client/README.md +1 -1
  2. vellum/client/core/client_wrapper.py +1 -1
  3. vellum/client/reference.md +2767 -0
  4. vellum/client/types/document_read.py +0 -1
  5. vellum/client/types/folder_entity_prompt_sandbox_data.py +1 -0
  6. vellum/client/types/folder_entity_workflow_sandbox_data.py +1 -0
  7. vellum/workflows/expressions/accessor.py +22 -5
  8. vellum/workflows/expressions/tests/test_accessor.py +189 -0
  9. vellum/workflows/nodes/bases/base.py +30 -39
  10. vellum/workflows/nodes/bases/tests/test_base_node.py +48 -2
  11. vellum/workflows/nodes/displayable/api_node/node.py +3 -1
  12. vellum/workflows/nodes/displayable/api_node/tests/test_api_node.py +32 -0
  13. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +28 -0
  14. vellum/workflows/nodes/displayable/bases/inline_prompt_node/tests/test_inline_prompt_node.py +26 -23
  15. vellum/workflows/nodes/displayable/conditional_node/node.py +1 -2
  16. vellum/workflows/nodes/displayable/final_output_node/node.py +2 -0
  17. vellum/workflows/nodes/displayable/inline_prompt_node/node.py +4 -14
  18. vellum/workflows/nodes/displayable/search_node/node.py +8 -0
  19. vellum/workflows/nodes/displayable/search_node/tests/test_node.py +19 -0
  20. vellum/workflows/nodes/experimental/tool_calling_node/utils.py +4 -13
  21. vellum/workflows/runner/runner.py +13 -17
  22. vellum/workflows/state/base.py +0 -4
  23. {vellum_ai-0.14.65.dist-info → vellum_ai-0.14.67.dist-info}/METADATA +2 -2
  24. {vellum_ai-0.14.65.dist-info → vellum_ai-0.14.67.dist-info}/RECORD +33 -30
  25. vellum_cli/image_push.py +62 -7
  26. vellum_cli/pull.py +38 -9
  27. vellum_cli/tests/test_image_push_error_handling.py +184 -0
  28. vellum_cli/tests/test_pull.py +12 -9
  29. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_inline_workflow_serialization.py +661 -0
  30. vellum_ee/workflows/display/utils/expressions.py +17 -0
  31. {vellum_ai-0.14.65.dist-info → vellum_ai-0.14.67.dist-info}/LICENSE +0 -0
  32. {vellum_ai-0.14.65.dist-info → vellum_ai-0.14.67.dist-info}/WHEEL +0 -0
  33. {vellum_ai-0.14.65.dist-info → vellum_ai-0.14.67.dist-info}/entry_points.txt +0 -0
@@ -32,7 +32,6 @@ class DocumentRead(UniversalBaseModel):
32
32
  """
33
33
 
34
34
  original_file_url: typing.Optional[str] = None
35
- processed_file_url: typing.Optional[str] = None
36
35
  document_to_document_indexes: typing.List[DocumentDocumentToDocumentIndex]
37
36
  metadata: typing.Optional[typing.Dict[str, typing.Optional[typing.Any]]] = pydantic.Field(default=None)
38
37
  """
@@ -14,6 +14,7 @@ class FolderEntityPromptSandboxData(UniversalBaseModel):
14
14
  created: dt.datetime
15
15
  modified: dt.datetime
16
16
  status: EntityStatus
17
+ description: typing.Optional[str] = None
17
18
  last_deployed_on: typing.Optional[dt.datetime] = None
18
19
 
19
20
  if IS_PYDANTIC_V2:
@@ -14,6 +14,7 @@ class FolderEntityWorkflowSandboxData(UniversalBaseModel):
14
14
  created: dt.datetime
15
15
  modified: dt.datetime
16
16
  status: EntityStatus
17
+ description: typing.Optional[str] = None
17
18
  last_deployed_on: typing.Optional[dt.datetime] = None
18
19
 
19
20
  if IS_PYDANTIC_V2:
@@ -35,20 +35,37 @@ class AccessorExpression(BaseDescriptor[Any]):
35
35
  if isinstance(self._field, int):
36
36
  raise InvalidExpressionException("Cannot access field by index on a dataclass")
37
37
 
38
- return getattr(base, self._field)
38
+ try:
39
+ return getattr(base, self._field)
40
+ except AttributeError:
41
+ raise InvalidExpressionException(f"Field '{self._field}' not found on dataclass {type(base).__name__}")
39
42
 
40
43
  if isinstance(base, BaseModel):
41
44
  if isinstance(self._field, int):
42
45
  raise InvalidExpressionException("Cannot access field by index on a BaseModel")
43
46
 
44
- return getattr(base, self._field)
47
+ try:
48
+ return getattr(base, self._field)
49
+ except AttributeError:
50
+ raise InvalidExpressionException(f"Field '{self._field}' not found on BaseModel {type(base).__name__}")
45
51
 
46
52
  if isinstance(base, Mapping):
47
- return base[self._field]
53
+ try:
54
+ return base[self._field]
55
+ except KeyError:
56
+ raise InvalidExpressionException(f"Key '{self._field}' not found in mapping")
48
57
 
49
58
  if isinstance(base, Sequence):
50
- index = int(self._field)
51
- return base[index]
59
+ try:
60
+ index = int(self._field)
61
+ return base[index]
62
+ except (IndexError, ValueError):
63
+ if isinstance(self._field, int) or (isinstance(self._field, str) and self._field.lstrip("-").isdigit()):
64
+ raise InvalidExpressionException(
65
+ f"Index {self._field} is out of bounds for sequence of length {len(base)}"
66
+ )
67
+ else:
68
+ raise InvalidExpressionException(f"Invalid index '{self._field}' for sequence access")
52
69
 
53
70
  raise InvalidExpressionException(f"Cannot get field {self._field} from {base}")
54
71
 
@@ -0,0 +1,189 @@
1
+ import pytest
2
+ from dataclasses import dataclass
3
+
4
+ from pydantic import BaseModel
5
+
6
+ from vellum.workflows.descriptors.exceptions import InvalidExpressionException
7
+ from vellum.workflows.expressions.accessor import AccessorExpression
8
+ from vellum.workflows.references.constant import ConstantValueReference
9
+ from vellum.workflows.state.base import BaseState
10
+
11
+
12
+ @dataclass
13
+ class TestDataclass:
14
+ name: str
15
+ value: int
16
+
17
+
18
+ class TestBaseModel(BaseModel):
19
+ name: str
20
+ value: int
21
+
22
+
23
+ class TestState(BaseState):
24
+ pass
25
+
26
+
27
+ def test_accessor_expression_dict_valid_key():
28
+ state = TestState()
29
+ base_ref = ConstantValueReference({"name": "test", "value": 42})
30
+ accessor = AccessorExpression(base=base_ref, field="name")
31
+
32
+ result = accessor.resolve(state)
33
+
34
+ assert result == "test"
35
+
36
+
37
+ def test_accessor_expression_dict_invalid_key():
38
+ state = TestState()
39
+ base_ref = ConstantValueReference({"name": "test", "value": 42})
40
+ accessor = AccessorExpression(base=base_ref, field="missing_key")
41
+
42
+ with pytest.raises(InvalidExpressionException) as exc_info:
43
+ accessor.resolve(state)
44
+
45
+ assert "Key 'missing_key' not found in mapping" in str(exc_info.value)
46
+
47
+
48
+ def test_accessor_expression_list_valid_index():
49
+ state = TestState()
50
+ base_ref = ConstantValueReference(["first", "second", "third"])
51
+ accessor = AccessorExpression(base=base_ref, field=1)
52
+
53
+ result = accessor.resolve(state)
54
+
55
+ assert result == "second"
56
+
57
+
58
+ def test_accessor_expression_list_invalid_index():
59
+ state = TestState()
60
+ base_ref = ConstantValueReference(["first", "second"])
61
+ accessor = AccessorExpression(base=base_ref, field=5)
62
+
63
+ with pytest.raises(InvalidExpressionException) as exc_info:
64
+ accessor.resolve(state)
65
+
66
+ assert "Index 5 is out of bounds for sequence of length 2" in str(exc_info.value)
67
+
68
+
69
+ def test_accessor_expression_list_negative_index():
70
+ state = TestState()
71
+ base_ref = ConstantValueReference(["first", "second", "third"])
72
+ accessor = AccessorExpression(base=base_ref, field=-1)
73
+
74
+ result = accessor.resolve(state)
75
+
76
+ assert result == "third"
77
+
78
+
79
+ def test_accessor_expression_list_invalid_negative_index():
80
+ state = TestState()
81
+ base_ref = ConstantValueReference(["first", "second"])
82
+ accessor = AccessorExpression(base=base_ref, field=-5)
83
+
84
+ with pytest.raises(InvalidExpressionException) as exc_info:
85
+ accessor.resolve(state)
86
+
87
+ assert "Index -5 is out of bounds for sequence of length 2" in str(exc_info.value)
88
+
89
+
90
+ def test_accessor_expression_list_string_index():
91
+ state = TestState()
92
+ base_ref = ConstantValueReference(["first", "second", "third"])
93
+ accessor = AccessorExpression(base=base_ref, field="1")
94
+
95
+ result = accessor.resolve(state)
96
+
97
+ assert result == "second"
98
+
99
+
100
+ def test_accessor_expression_list_invalid_string_index():
101
+ state = TestState()
102
+ base_ref = ConstantValueReference(["first", "second"])
103
+ accessor = AccessorExpression(base=base_ref, field="invalid")
104
+
105
+ with pytest.raises(InvalidExpressionException) as exc_info:
106
+ accessor.resolve(state)
107
+
108
+ assert "Invalid index 'invalid' for sequence access" in str(exc_info.value)
109
+
110
+
111
+ def test_accessor_expression_dataclass_valid_field():
112
+ state = TestState()
113
+ test_obj = TestDataclass(name="test", value=42)
114
+ base_ref = ConstantValueReference(test_obj)
115
+ accessor = AccessorExpression(base=base_ref, field="name")
116
+
117
+ result = accessor.resolve(state)
118
+
119
+ assert result == "test"
120
+
121
+
122
+ def test_accessor_expression_dataclass_invalid_field():
123
+ state = TestState()
124
+ test_obj = TestDataclass(name="test", value=42)
125
+ base_ref = ConstantValueReference(test_obj)
126
+ accessor = AccessorExpression(base=base_ref, field="missing_field")
127
+
128
+ with pytest.raises(InvalidExpressionException) as exc_info:
129
+ accessor.resolve(state)
130
+
131
+ assert "Field 'missing_field' not found on dataclass TestDataclass" in str(exc_info.value)
132
+
133
+
134
+ def test_accessor_expression_basemodel_valid_field():
135
+ state = TestState()
136
+ test_obj = TestBaseModel(name="test", value=42)
137
+ base_ref = ConstantValueReference(test_obj)
138
+ accessor = AccessorExpression(base=base_ref, field="name")
139
+
140
+ result = accessor.resolve(state)
141
+
142
+ assert result == "test"
143
+
144
+
145
+ def test_accessor_expression_basemodel_invalid_field():
146
+ state = TestState()
147
+ test_obj = TestBaseModel(name="test", value=42)
148
+ base_ref = ConstantValueReference(test_obj)
149
+ accessor = AccessorExpression(base=base_ref, field="missing_field")
150
+
151
+ with pytest.raises(InvalidExpressionException) as exc_info:
152
+ accessor.resolve(state)
153
+
154
+ assert "Field 'missing_field' not found on BaseModel TestBaseModel" in str(exc_info.value)
155
+
156
+
157
+ def test_accessor_expression_dataclass_index_access():
158
+ state = TestState()
159
+ test_obj = TestDataclass(name="test", value=42)
160
+ base_ref = ConstantValueReference(test_obj)
161
+ accessor = AccessorExpression(base=base_ref, field=0)
162
+
163
+ with pytest.raises(InvalidExpressionException) as exc_info:
164
+ accessor.resolve(state)
165
+
166
+ assert "Cannot access field by index on a dataclass" in str(exc_info.value)
167
+
168
+
169
+ def test_accessor_expression_basemodel_index_access():
170
+ state = TestState()
171
+ test_obj = TestBaseModel(name="test", value=42)
172
+ base_ref = ConstantValueReference(test_obj)
173
+ accessor = AccessorExpression(base=base_ref, field=0)
174
+
175
+ with pytest.raises(InvalidExpressionException) as exc_info:
176
+ accessor.resolve(state)
177
+
178
+ assert "Cannot access field by index on a BaseModel" in str(exc_info.value)
179
+
180
+
181
+ def test_accessor_expression_unsupported_type():
182
+ state = TestState()
183
+ base_ref = ConstantValueReference(42)
184
+ accessor = AccessorExpression(base=base_ref, field="field")
185
+
186
+ with pytest.raises(InvalidExpressionException) as exc_info:
187
+ accessor.resolve(state)
188
+
189
+ assert "Cannot get field field from 42" in str(exc_info.value)
@@ -10,6 +10,7 @@ from vellum.workflows.constants import undefined
10
10
  from vellum.workflows.descriptors.base import BaseDescriptor
11
11
  from vellum.workflows.descriptors.utils import is_unresolved, resolve_value
12
12
  from vellum.workflows.errors.types import WorkflowErrorCode
13
+ from vellum.workflows.events.node import NodeExecutionStreamingEvent
13
14
  from vellum.workflows.exceptions import NodeException
14
15
  from vellum.workflows.graph import Graph
15
16
  from vellum.workflows.graph.graph import GraphTarget
@@ -92,7 +93,7 @@ class BaseNodeMeta(ABCMeta):
92
93
  if issubclass(base, BaseNode):
93
94
  ports_dct = {p.name: Port(default=p.default) for p in base.Ports}
94
95
  ports_dct["__module__"] = dct["__module__"]
95
- dct["Ports"] = type(f"{name}.Ports", (NodePorts,), ports_dct)
96
+ dct["Ports"] = type(f"{name}.Ports", (base.Ports,), ports_dct)
96
97
  break
97
98
 
98
99
  if "Execution" not in dct:
@@ -357,41 +358,15 @@ class BaseNode(Generic[StateType], ABC, metaclass=BaseNodeMeta):
357
358
  state.meta.node_execution_cache._dependencies_invoked[execution_id].add(invoked_by)
358
359
  return execution_id
359
360
 
360
- # For AWAIT_ANY in workflows, we have two cases:
361
- # 1. The node is being re-executed because of a fork
362
- # 2. The node is being re-executed because of a loop
363
- # For case 1, we need to remove the fork id from the node_to_fork_id mapping
364
- # For case 2, we need to check if the node is in a loop
361
+ # For AWAIT_ANY in workflows, we need to detect if the node is in a loop
362
+ # If the node is in a loop, we can trigger the node again
365
363
  in_loop = False
366
- # Default to true, will be set to false if the merged node has already been triggered
367
- should_retrigger = True
368
364
  if cls.merge_behavior == MergeBehavior.AWAIT_ANY:
369
- # Get the node that triggered the current execution
370
- invoked_by_node = state.meta.node_execution_cache.__node_execution_lookup__.get(invoked_by)
371
-
372
- # Check if invoked by node is a forked node
373
- if invoked_by_node is not None:
374
- fork_id = state.meta.node_execution_cache.__node_to_fork_id__.get(invoked_by_node, None)
375
- if fork_id:
376
- # If the invoked by node has a fork id and that fork id is in __all_fork_ids__
377
- # We will
378
- # 1. remove the fork id from __all_fork_ids__
379
- # 2. remove the fork id from the __node_to_fork_id__ mapping
380
- # else (this mean the fork has already been triggered)
381
- # remove the id from the node_to_fork_id mapping and not retrigger again
382
- all_fork_ids = state.meta.node_execution_cache.__all_fork_ids__
383
- if fork_id in all_fork_ids:
384
- # When the next forked node merge, it will not trigger the node again
385
- # We should consider adding a lock here to prevent race condition
386
- all_fork_ids.remove(fork_id)
387
- state.meta.node_execution_cache.__node_to_fork_id__.pop(invoked_by_node, None)
388
- else:
389
- should_retrigger = False
390
- state.meta.node_execution_cache.__node_to_fork_id__.pop(invoked_by_node, None)
391
-
392
- # If should_retrigger is false, then we will not trigger the node already
393
- # So we don't need to check loop behavior
394
- if should_retrigger:
365
+ # Get the latest fulfilled execution ID of current node
366
+ fulfilled_stack = state.meta.node_execution_cache._node_executions_fulfilled[cls.node_class]
367
+ current_latest_fulfilled_id = fulfilled_stack.peek() if not fulfilled_stack.is_empty() else None
368
+ # If the current node has not been fulfilled yet, we don't need to check for loop
369
+ if current_latest_fulfilled_id is not None:
395
370
  # Trace back through the dependency chain to detect if this node triggers itself
396
371
  visited = set()
397
372
  current_execution_id = invoked_by
@@ -417,15 +392,18 @@ class BaseNode(Generic[StateType], ABC, metaclass=BaseNodeMeta):
417
392
  current_execution_id
418
393
  )
419
394
 
420
- # If we've found our target node class in the chain, we're in a loop
395
+ # If we've found our target node class in the chain
421
396
  if current_node_class == cls.node_class:
422
- in_loop = True
397
+ # Check if the execution id is the same as the latest fulfilled execution id
398
+ # If yes, we're in a loop
399
+ if current_execution_id == current_latest_fulfilled_id:
400
+ in_loop = True
401
+ # If not, current node has been triggered by other node,
402
+ # we can break out of the loop
423
403
  break
424
404
 
425
405
  for queued_node_execution_id in state.meta.node_execution_cache._node_executions_queued[cls.node_class]:
426
- # When should_retrigger is false, it means the merged node has already been triggered
427
- # So we don't need to trigger the node again
428
- if not should_retrigger or (
406
+ if (
429
407
  invoked_by not in state.meta.node_execution_cache._dependencies_invoked[queued_node_execution_id]
430
408
  and not in_loop
431
409
  ):
@@ -509,3 +487,16 @@ class BaseNode(Generic[StateType], ABC, metaclass=BaseNodeMeta):
509
487
 
510
488
  def __repr__(self) -> str:
511
489
  return str(self.__class__)
490
+
491
+ __simulates_workflow_output__ = False
492
+
493
+ def __directly_emit_workflow_output__(
494
+ self, event: NodeExecutionStreamingEvent, workflow_output_descriptor: OutputReference
495
+ ) -> bool:
496
+ """
497
+ In the legacy workflow runner, there was support for emitting streaming workflow outputs for prompt nodes
498
+ connected to terminal nodes. These two private methods provides a hacky, intentionally short-lived workaround
499
+ for us to enable this until we can directly reference prompt outputs from the UI.
500
+ """
501
+
502
+ return False
@@ -1,6 +1,6 @@
1
1
  import pytest
2
2
  from uuid import UUID
3
- from typing import Optional
3
+ from typing import Optional, Set
4
4
 
5
5
  from vellum.client.types.string_vellum_value_request import StringVellumValueRequest
6
6
  from vellum.core.pydantic_utilities import UniversalBaseModel
@@ -9,7 +9,9 @@ from vellum.workflows.descriptors.tests.test_utils import FixtureState
9
9
  from vellum.workflows.inputs.base import BaseInputs
10
10
  from vellum.workflows.nodes import FinalOutputNode
11
11
  from vellum.workflows.nodes.bases.base import BaseNode
12
- from vellum.workflows.outputs.base import BaseOutputs
12
+ from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
13
+ from vellum.workflows.ports.port import Port
14
+ from vellum.workflows.references.constant import ConstantValueReference
13
15
  from vellum.workflows.references.node import NodeReference
14
16
  from vellum.workflows.references.output import OutputReference
15
17
  from vellum.workflows.state.base import BaseState, StateMeta
@@ -333,3 +335,47 @@ def test_base_node__node_reference_of_inherited_annotation():
333
335
  # THEN the node reference is of the correct type
334
336
  assert isinstance(node_reference, NodeReference)
335
337
  assert node_reference.name == "foo"
338
+
339
+
340
+ def test_base_node__ports_inheritance():
341
+ # GIVEN a node with one port
342
+ class MyNode(BaseNode):
343
+ class Ports(BaseNode.Ports):
344
+ foo = Port.on_if(ConstantValueReference(True))
345
+
346
+ def __lt__(self, output: BaseOutput) -> Set[Port]:
347
+ return {self.foo}
348
+
349
+ # AND a node that inherits from MyNode
350
+ class InheritedNode(MyNode):
351
+ pass
352
+
353
+ # WHEN we collect the ports
354
+ ports = [port.name for port in InheritedNode.Ports]
355
+
356
+ # THEN the inheritance is correct
357
+ inherited_ports = InheritedNode.Ports()
358
+ assert inherited_ports.__lt__(BaseOutput(name="foo")) == {InheritedNode.Ports.foo}
359
+
360
+ # AND the ports are correct
361
+ assert ports == ["foo"]
362
+
363
+
364
+ def test_base_node__ports_inheritance__cumulative_ports():
365
+ # GIVEN a node with one port
366
+ class MyNode(BaseNode):
367
+ class Ports(BaseNode.Ports):
368
+ foo = Port.on_if(ConstantValueReference(True))
369
+
370
+ # AND a node that inherits from MyNode with another port
371
+ class InheritedNode(MyNode):
372
+ class Ports(MyNode.Ports):
373
+ bar = Port.on_if(ConstantValueReference(True))
374
+
375
+ # WHEN we collect the ports
376
+ ports = [port.name for port in InheritedNode.Ports]
377
+
378
+ # THEN the ports are correct
379
+ # Potentially in the future, we support inheriting ports from multiple parents.
380
+ # For now, we take only the declared ports, so that not all nodes have the default port.
381
+ assert ports == ["bar"]
@@ -51,7 +51,9 @@ class APINode(BaseAPINode):
51
51
  final_headers = {**headers, **header_overrides}
52
52
 
53
53
  vellum_client_wrapper = self._context.vellum_client._client_wrapper
54
- if self.url.startswith(vellum_client_wrapper._environment.default) and "X-API-Key" not in final_headers:
54
+ if self.url.startswith(vellum_client_wrapper._environment.default) and (
55
+ "X-API-Key" not in final_headers and "X_API_KEY" not in final_headers
56
+ ):
55
57
  final_headers["X-API-Key"] = vellum_client_wrapper.api_key
56
58
 
57
59
  return self._run(
@@ -158,3 +158,35 @@ def test_api_node__detects_client_environment_urls__does_not_override_headers(
158
158
  # AND the vellum API should have been called with the correct headers
159
159
  assert mock_response.last_request
160
160
  assert mock_response.last_request.headers["X-API-Key"] == "vellum-api-key-5678"
161
+
162
+
163
+ def test_api_node__detects_client_environment_urls__legacy_does_not_override_headers(
164
+ mock_httpx_transport, mock_requests, monkeypatch
165
+ ):
166
+ # GIVEN an API node with a URL pointing back to Vellum
167
+ class SimpleAPINodeToVellum(APINode):
168
+ url = "https://api.vellum.ai"
169
+ headers = {
170
+ "X_API_KEY": "vellum-api-key-5678",
171
+ }
172
+
173
+ # AND a mock request sent to the Vellum API would return a 200
174
+ mock_response = mock_requests.get(
175
+ "https://api.vellum.ai",
176
+ status_code=200,
177
+ json={"data": [1, 2, 3]},
178
+ )
179
+
180
+ # AND an api key is set
181
+ monkeypatch.setenv("VELLUM_API_KEY", "vellum-api-key-1234")
182
+
183
+ # WHEN we run the node
184
+ node = SimpleAPINodeToVellum()
185
+ node.run()
186
+
187
+ # THEN the execute_api method should not have been called
188
+ mock_httpx_transport.handle_request.assert_not_called()
189
+
190
+ # AND the vellum API should have been called with the correct headers
191
+ assert mock_response.last_request
192
+ assert mock_response.last_request.headers["X_API_KEY"] == "vellum-api-key-5678"
@@ -5,9 +5,11 @@ from vellum import AdHocExecutePromptEvent, ExecutePromptEvent, PromptOutput
5
5
  from vellum.client.core.api_error import ApiError
6
6
  from vellum.core import RequestOptions
7
7
  from vellum.workflows.errors.types import WorkflowErrorCode, vellum_error_to_workflow_error
8
+ from vellum.workflows.events.node import NodeExecutionStreamingEvent
8
9
  from vellum.workflows.exceptions import NodeException
9
10
  from vellum.workflows.nodes.bases import BaseNode
10
11
  from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
12
+ from vellum.workflows.references.output import OutputReference
11
13
  from vellum.workflows.types.core import EntityInputsInterface, MergeBehavior
12
14
  from vellum.workflows.types.generics import StateType
13
15
 
@@ -85,3 +87,29 @@ class BasePromptNode(BaseNode, Generic[StateType]):
85
87
  message="Failed to execute Prompt",
86
88
  code=WorkflowErrorCode.INTERNAL_ERROR,
87
89
  ) from e
90
+
91
+ def __directly_emit_workflow_output__(
92
+ self,
93
+ event: NodeExecutionStreamingEvent,
94
+ workflow_output_descriptor: OutputReference,
95
+ ) -> bool:
96
+ if event.output.name != "results":
97
+ return False
98
+
99
+ if not isinstance(event.output.delta, str) and not event.output.is_initiated:
100
+ return False
101
+
102
+ target_nodes = [e.to_node for e in self.Ports.default.edges if e.to_node.__simulates_workflow_output__]
103
+ target_node_output = next(
104
+ (
105
+ o
106
+ for target_node in target_nodes
107
+ for o in target_node.Outputs
108
+ if o == workflow_output_descriptor.instance
109
+ ),
110
+ None,
111
+ )
112
+ if not target_node_output:
113
+ return False
114
+
115
+ return True
@@ -192,7 +192,30 @@ def test_validation_with_extra_variables(vellum_adhoc_prompt_client):
192
192
  ]
193
193
 
194
194
 
195
- def test_inline_prompt_node__json_output(vellum_adhoc_prompt_client):
195
+ @pytest.mark.parametrize(
196
+ "custom_parameters,test_description",
197
+ [
198
+ (
199
+ {
200
+ "json_mode": False,
201
+ "json_schema": {
202
+ "name": "get_result",
203
+ "schema": {
204
+ "type": "object",
205
+ "required": ["result"],
206
+ "properties": {"result": {"type": "string", "description": ""}},
207
+ },
208
+ },
209
+ },
210
+ "with json_schema configured",
211
+ ),
212
+ (
213
+ {},
214
+ "without json_mode or json_schema configured",
215
+ ),
216
+ ],
217
+ )
218
+ def test_inline_prompt_node__json_output(vellum_adhoc_prompt_client, custom_parameters, test_description):
196
219
  """Confirm that InlinePromptNodes output the expected JSON when run."""
197
220
 
198
221
  # GIVEN a node that subclasses InlinePromptNode
@@ -214,17 +237,7 @@ def test_inline_prompt_node__json_output(vellum_adhoc_prompt_client):
214
237
  frequency_penalty=0.0,
215
238
  presence_penalty=0.0,
216
239
  logit_bias=None,
217
- custom_parameters={
218
- "json_mode": False,
219
- "json_schema": {
220
- "name": "get_result",
221
- "schema": {
222
- "type": "object",
223
- "required": ["result"],
224
- "properties": {"result": {"type": "string", "description": ""}},
225
- },
226
- },
227
- },
240
+ custom_parameters=custom_parameters,
228
241
  )
229
242
 
230
243
  # AND a known JSON response from invoking an inline prompt
@@ -284,17 +297,7 @@ def test_inline_prompt_node__json_output(vellum_adhoc_prompt_client):
284
297
  frequency_penalty=0.0,
285
298
  presence_penalty=0.0,
286
299
  logit_bias=None,
287
- custom_parameters={
288
- "json_mode": False,
289
- "json_schema": {
290
- "name": "get_result",
291
- "schema": {
292
- "type": "object",
293
- "required": ["result"],
294
- "properties": {"result": {"type": "string", "description": ""}},
295
- },
296
- },
297
- },
300
+ custom_parameters=custom_parameters,
298
301
  ),
299
302
  request_options=mock.ANY,
300
303
  settings=None,
@@ -2,7 +2,6 @@ from typing import Set
2
2
 
3
3
  from vellum.workflows.nodes.bases import BaseNode
4
4
  from vellum.workflows.outputs.base import BaseOutputs
5
- from vellum.workflows.ports.node_ports import NodePorts
6
5
  from vellum.workflows.ports.port import Port
7
6
  from vellum.workflows.ports.utils import validate_ports
8
7
  from vellum.workflows.state.base import BaseState
@@ -18,7 +17,7 @@ class ConditionalNode(BaseNode):
18
17
  class Trigger(BaseNode.Trigger):
19
18
  merge_behavior = MergeBehavior.AWAIT_ANY
20
19
 
21
- class Ports(NodePorts):
20
+ class Ports(BaseNode.Ports):
22
21
  def __call__(self, outputs: BaseOutputs, state: BaseState) -> Set[Port]:
23
22
  all_ports = [port for port in self.__class__]
24
23
  enforce_single_invoked_port = validate_ports(all_ports)
@@ -61,3 +61,5 @@ class FinalOutputNode(BaseNode[StateType], Generic[StateType, _OutputType], meta
61
61
  self.__class__.get_output_type(),
62
62
  )
63
63
  )
64
+
65
+ __simulates_workflow_output__ = True
@@ -48,26 +48,16 @@ class InlinePromptNode(BaseInlinePromptNode[StateType]):
48
48
  string_outputs = []
49
49
  json_output = None
50
50
 
51
- should_parse_json = False
52
- if hasattr(self, "parameters"):
53
- custom_params = self.parameters.custom_parameters
54
- if custom_params and isinstance(custom_params, dict):
55
- json_schema = custom_params.get("json_schema", {})
56
- if (isinstance(json_schema, dict) and "schema" in json_schema) or custom_params.get("json_mode", {}):
57
- should_parse_json = True
58
-
59
51
  for output in outputs:
60
52
  if output.value is None:
61
53
  continue
62
54
 
63
55
  if output.type == "STRING":
64
56
  string_outputs.append(output.value)
65
- if should_parse_json:
66
- try:
67
- parsed_json = json.loads(output.value)
68
- json_output = parsed_json
69
- except (json.JSONDecodeError, TypeError):
70
- pass
57
+ try:
58
+ json_output = json.loads(output.value)
59
+ except (json.JSONDecodeError, TypeError):
60
+ pass
71
61
  elif output.type == "JSON":
72
62
  string_outputs.append(json.dumps(output.value, indent=4))
73
63
  json_output = output.value
@@ -1,6 +1,8 @@
1
1
  import json
2
2
  from typing import ClassVar
3
3
 
4
+ from vellum.workflows.errors import WorkflowErrorCode
5
+ from vellum.workflows.exceptions import NodeException
4
6
  from vellum.workflows.nodes.displayable.bases import BaseSearchNode as BaseSearchNode
5
7
  from vellum.workflows.state.encoder import DefaultStateEncoder
6
8
  from vellum.workflows.types import MergeBehavior
@@ -35,6 +37,12 @@ class SearchNode(BaseSearchNode[StateType]):
35
37
  text: str
36
38
 
37
39
  def run(self) -> Outputs:
40
+ if self.query is None or self.query == "":
41
+ raise NodeException(
42
+ message="Search query is required but was not provided",
43
+ code=WorkflowErrorCode.INVALID_INPUTS,
44
+ )
45
+
38
46
  if not isinstance(self.query, str):
39
47
  self.query = json.dumps(self.query, cls=DefaultStateEncoder)
40
48