vellum-ai 0.13.0__py3-none-any.whl → 0.13.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/core/pydantic_utilities.py +5 -0
- vellum/client/resources/workflows/client.py +8 -0
- vellum/client/types/logical_operator.py +2 -0
- vellum/workflows/descriptors/base.py +1 -1
- vellum/workflows/descriptors/tests/test_utils.py +3 -0
- vellum/workflows/expressions/accessor.py +8 -2
- vellum/workflows/nodes/core/map_node/node.py +49 -24
- vellum/workflows/nodes/core/map_node/tests/test_node.py +4 -4
- vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +1 -1
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +5 -3
- vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +3 -0
- vellum/workflows/nodes/displayable/bases/search_node.py +37 -2
- vellum/workflows/nodes/displayable/bases/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/bases/tests/test_utils.py +61 -0
- vellum/workflows/nodes/displayable/bases/types.py +42 -0
- vellum/workflows/nodes/displayable/bases/utils.py +112 -0
- vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +0 -1
- vellum/workflows/nodes/displayable/search_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/search_node/tests/test_node.py +164 -0
- vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py +2 -3
- vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +0 -1
- vellum/workflows/runner/runner.py +37 -4
- vellum/workflows/types/tests/test_utils.py +5 -2
- vellum/workflows/types/utils.py +4 -0
- vellum/workflows/workflows/base.py +14 -0
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/METADATA +1 -1
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/RECORD +46 -36
- vellum_cli/__init__.py +10 -0
- vellum_cli/ping.py +28 -0
- vellum_cli/tests/test_ping.py +47 -0
- vellum_ee/workflows/display/nodes/vellum/base_node.py +22 -9
- vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +3 -0
- vellum_ee/workflows/display/nodes/vellum/map_node.py +1 -1
- vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +14 -10
- vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/utils.py +8 -1
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +67 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +66 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_ports_serialization.py +660 -0
- vellum_ee/workflows/display/utils/vellum.py +4 -42
- vellum_ee/workflows/display/vellum.py +7 -36
- vellum_ee/workflows/display/workflows/vellum_workflow_display.py +2 -1
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/LICENSE +0 -0
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/WHEEL +0 -0
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/entry_points.txt +0 -0
vellum/client/core/client_wrapper.py
@@ -18,7 +18,7 @@ class BaseClientWrapper:
         headers: typing.Dict[str, str] = {
             "X-Fern-Language": "Python",
             "X-Fern-SDK-Name": "vellum-ai",
-            "X-Fern-SDK-Version": "0.13.0",
+            "X-Fern-SDK-Version": "0.13.1",
         }
         headers["X_API_KEY"] = self.api_key
         return headers
vellum/client/core/pydantic_utilities.py
@@ -133,6 +133,11 @@ class UniversalBaseModel(pydantic.BaseModel):
         #
         # We'd ideally do the same for Pydantic V2, but it shells out to a library to serialize models
         # that we have less control over, and this is less intrusive than custom serializers for now.
+        kwargs = {
+            **kwargs,
+            "warnings": False,
+        }
+
         if IS_PYDANTIC_V2:
             kwargs_with_defaults_exclude_unset: typing.Any = {
                 **kwargs,
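The block added here merges `warnings: False` into whatever keyword arguments the caller passed before dumping, so Pydantic V2's serializer warnings are suppressed on every dump. A tiny standalone sketch of the same merge semantics (plain Python, not the library's code):

def dump(**kwargs):
    # Later keys win in a dict literal, so "warnings" is always forced to False,
    # even if the caller explicitly passed warnings=True.
    kwargs = {
        **kwargs,
        "warnings": False,
    }
    return kwargs

assert dump(exclude_unset=True, warnings=True) == {"exclude_unset": True, "warnings": False}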
vellum/client/resources/workflows/client.py
@@ -101,6 +101,7 @@ class WorkflowsClient:
         deployment_config: typing.Optional[WorkflowPushDeploymentConfigRequest] = OMIT,
         artifact: typing.Optional[core.File] = OMIT,
         dry_run: typing.Optional[bool] = OMIT,
+        strict: typing.Optional[bool] = OMIT,
         request_options: typing.Optional[RequestOptions] = None,
     ) -> WorkflowPushResponse:
         """
@@ -120,6 +121,8 @@ class WorkflowsClient:

         dry_run : typing.Optional[bool]

+        strict : typing.Optional[bool]
+
         request_options : typing.Optional[RequestOptions]
             Request-specific configuration.

@@ -150,6 +153,7 @@ class WorkflowsClient:
                 "workflow_sandbox_id": workflow_sandbox_id,
                 "deployment_config": deployment_config,
                 "dry_run": dry_run,
+                "strict": strict,
             },
             files={
                 "artifact": artifact,
@@ -254,6 +258,7 @@ class AsyncWorkflowsClient:
         deployment_config: typing.Optional[WorkflowPushDeploymentConfigRequest] = OMIT,
         artifact: typing.Optional[core.File] = OMIT,
         dry_run: typing.Optional[bool] = OMIT,
+        strict: typing.Optional[bool] = OMIT,
         request_options: typing.Optional[RequestOptions] = None,
     ) -> WorkflowPushResponse:
         """
@@ -273,6 +278,8 @@ class AsyncWorkflowsClient:

         dry_run : typing.Optional[bool]

+        strict : typing.Optional[bool]
+
         request_options : typing.Optional[RequestOptions]
             Request-specific configuration.

@@ -311,6 +318,7 @@ class AsyncWorkflowsClient:
                 "workflow_sandbox_id": workflow_sandbox_id,
                 "deployment_config": deployment_config,
                 "dry_run": dry_run,
+                "strict": strict,
             },
             files={
                 "artifact": artifact,
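Both the sync and async workflows clients gain an optional `strict` flag that is forwarded in the request body. A minimal sketch of how a caller might opt in; the method name (`push`) and the top-level `Vellum` client export are assumptions based on the surrounding code, since neither appears in these hunks:

from vellum import Vellum  # assumed top-level client export; not shown in this diff

client = Vellum(api_key="YOUR_API_KEY")

# Only parameters visible in the hunks above are shown; any other required
# arguments of the (assumed) push method are omitted from this sketch.
response = client.workflows.push(
    workflow_sandbox_id="...",  # appears in the request body above
    dry_run=True,
    strict=True,  # new in 0.13.1
)
print(response)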
vellum/workflows/descriptors/base.py
@@ -121,7 +121,7 @@ class BaseDescriptor(Generic[_T]):

         return CoalesceExpression(lhs=self, rhs=other)

-    def __getitem__(self, field: str) -> "AccessorExpression":
+    def __getitem__(self, field: Union[str, int]) -> "AccessorExpression":
         from vellum.workflows.expressions.accessor import AccessorExpression

         return AccessorExpression(base=self, field=field)
vellum/workflows/descriptors/tests/test_utils.py
@@ -19,6 +19,7 @@ class FixtureState(BaseState):
     }

     eta = None
+    theta = ["baz"]


 class DummyNode(BaseNode[FixtureState]):
@@ -75,6 +76,7 @@ class DummyNode(BaseNode[FixtureState]):
         ),
         (FixtureState.zeta["foo"], "bar"),
         (ConstantValueReference(1), 1),
+        (FixtureState.theta[0], "baz"),
     ],
     ids=[
         "or",
@@ -119,6 +121,7 @@ class DummyNode(BaseNode[FixtureState]):
         "or_and",
         "accessor",
         "constants",
+        "list_index",
     ],
 )
 def test_resolve_value__happy_path(descriptor, expected_value):
vellum/workflows/expressions/accessor.py
@@ -1,6 +1,6 @@
 from collections.abc import Mapping
 import dataclasses
-from typing import Any, Sequence, Type, TypeVar
+from typing import Any, Sequence, Type, TypeVar, Union

 from pydantic import BaseModel, GetCoreSchemaHandler
 from pydantic_core import core_schema
@@ -17,7 +17,7 @@ class AccessorExpression(BaseDescriptor[Any]):
         self,
         *,
         base: BaseDescriptor[LHS],
-        field: str,
+        field: Union[str, int],
     ) -> None:
         super().__init__(
             name=f"{base.name}.{field}",
@@ -31,9 +31,15 @@ class AccessorExpression(BaseDescriptor[Any]):
         base = resolve_value(self._base, state)

         if dataclasses.is_dataclass(base):
+            if isinstance(self._field, int):
+                raise ValueError("Cannot access field by index on a dataclass")
+
             return getattr(base, self._field)

         if isinstance(base, BaseModel):
+            if isinstance(self._field, int):
+                raise ValueError("Cannot access field by index on a BaseModel")
+
             return getattr(base, self._field)

         if isinstance(base, Mapping):
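Together with the `__getitem__` change in descriptors/base.py above, descriptors can now be indexed by integer as well as by string key, while dataclasses and Pydantic models explicitly reject integer access. A minimal sketch based on the updated test fixtures (`theta = ["baz"]`); the class name here is illustrative:

from vellum.workflows.state.base import BaseState


class MyState(BaseState):
    theta = ["baz"]


# Indexing the descriptor builds an AccessorExpression whose field is an int;
# it resolves against list-like values at runtime (the "list_index" test case above),
# so this reference evaluates to "baz".
reference = MyState.theta[0]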
vellum/workflows/nodes/core/map_node/node.py
@@ -1,7 +1,20 @@
 from collections import defaultdict
 from queue import Empty, Queue
 from threading import Thread
-from typing import
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    Dict,
+    Generic,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    overload,
+)

 from vellum.workflows.context import execution_context, get_parent_context
 from vellum.workflows.descriptors.base import BaseDescriptor
@@ -12,6 +25,7 @@ from vellum.workflows.inputs.base import BaseInputs
 from vellum.workflows.nodes.bases.base_adornment_node import BaseAdornmentNode
 from vellum.workflows.nodes.utils import create_adornment
 from vellum.workflows.outputs import BaseOutputs
+from vellum.workflows.outputs.base import BaseOutput
 from vellum.workflows.references.output import OutputReference
 from vellum.workflows.state.context import WorkflowContext
 from vellum.workflows.types.generics import StateType
@@ -29,11 +43,11 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])

     items: List[MapNodeItemType] - The items to map over
     subworkflow: Type["BaseWorkflow[SubworkflowInputs, BaseState]"] - The Subworkflow to execute on each iteration
-
+    max_concurrency: Optional[int] = None - The maximum number of concurrent subworkflow executions
     """

     items: List[MapNodeItemType]
-
+    max_concurrency: Optional[int] = None

     class Outputs(BaseAdornmentNode.Outputs):
         pass
@@ -47,7 +61,7 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
         index: int
         all_items: List[MapNodeItemType]  # type: ignore[valid-type]

-    def run(self) ->
+    def run(self) -> Iterator[BaseOutput]:
         mapped_items: Dict[str, List] = defaultdict(list)
         for output_descripter in self.subworkflow.Outputs:
             mapped_items[output_descripter.name] = [None] * len(self.items)
@@ -66,14 +80,14 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
                     "parent_context": parent_context,
                 },
             )
-            if self.
+            if self.max_concurrency is None:
                 thread.start()
             else:
                 self._concurrency_queue.put(thread)

-        if self.
+        if self.max_concurrency is not None:
             concurrency_count = 0
-            while concurrency_count < self.
+            while concurrency_count < self.max_concurrency:
                 is_empty = self._start_thread()
                 if is_empty:
                     break
@@ -83,40 +97,45 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
         try:
             while map_node_event := self._event_queue.get():
                 index = map_node_event[0]
-
-                self._context._emit_subworkflow_event(
+                subworkflow_event = map_node_event[1]
+                self._context._emit_subworkflow_event(subworkflow_event)

-                if
-
+                if subworkflow_event.name == "workflow.execution.initiated":
+                    for output_name in mapped_items.keys():
+                        yield BaseOutput(name=output_name, delta=(None, index, "INITIATED"))
+
+                elif subworkflow_event.name == "workflow.execution.fulfilled":
+                    workflow_output_vars = vars(subworkflow_event.outputs)

                     for output_name in workflow_output_vars:
                         output_mapped_items = mapped_items[output_name]
                         output_mapped_items[index] = workflow_output_vars[output_name]
+                        yield BaseOutput(
+                            name=output_name,
+                            delta=(output_mapped_items[index], index, "FULFILLED"),
+                        )

                     fulfilled_iterations[index] = True
                     if all(fulfilled_iterations):
                         break

-                    if self.
+                    if self.max_concurrency is not None:
                         self._start_thread()
-                elif
+                elif subworkflow_event.name == "workflow.execution.paused":
                     raise NodeException(
                         code=WorkflowErrorCode.INVALID_OUTPUTS,
                         message=f"Subworkflow unexpectedly paused on iteration {index}",
                     )
-                elif
+                elif subworkflow_event.name == "workflow.execution.rejected":
                     raise NodeException(
-                        f"Subworkflow failed on iteration {index} with error: {
-                        code=
+                        f"Subworkflow failed on iteration {index} with error: {subworkflow_event.error.message}",
+                        code=subworkflow_event.error.code,
                     )
         except Empty:
             pass

-        outputs = self.Outputs()
         for output_name, output_list in mapped_items.items():
-
-
-            return outputs
+            yield BaseOutput(name=output_name, value=output_list)

     def _context_run_subworkflow(
         self, *, item: MapNodeItemType, index: int, parent_context: Optional[ParentContext] = None
@@ -149,21 +168,27 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])

     @overload
     @classmethod
-    def wrap(
+    def wrap(
+        cls, items: List[MapNodeItemType], max_concurrency: Optional[int] = None
+    ) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]: ...

     # TODO: We should be able to do this overload automatically as we do with node attributes
     # https://app.shortcut.com/vellum/story/5289
     @overload
     @classmethod
     def wrap(
-        cls,
+        cls,
+        items: BaseDescriptor[List[MapNodeItemType]],
+        max_concurrency: Optional[int] = None,
     ) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]: ...

     @classmethod
     def wrap(
-        cls,
+        cls,
+        items: Union[List[MapNodeItemType], BaseDescriptor[List[MapNodeItemType]]],
+        max_concurrency: Optional[int] = None,
    ) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]:
-        return create_adornment(cls, attributes={"items": items})
+        return create_adornment(cls, attributes={"items": items, "max_concurrency": max_concurrency})

     @classmethod
     def __annotate_outputs_class__(cls, outputs_class: Type[BaseOutputs], reference: OutputReference) -> None:
vellum/workflows/nodes/core/map_node/tests/test_node.py
@@ -3,7 +3,7 @@ import time
 from vellum.workflows.inputs.base import BaseInputs
 from vellum.workflows.nodes.bases import BaseNode
 from vellum.workflows.nodes.core.map_node.node import MapNode
-from vellum.workflows.outputs.base import BaseOutputs
+from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
 from vellum.workflows.state.base import BaseState, StateMeta


@@ -35,10 +35,10 @@ def test_map_node__use_parent_inputs_and_state():
             meta=StateMeta(workflow_inputs=Inputs(foo="foo")),
         )
     )
-    outputs = node.run()
+    outputs = list(node.run())

     # THEN the data is used successfully
-    assert outputs
+    assert outputs[-1] == BaseOutput(name="value", value=["foo bar 1", "foo bar 2", "foo bar 3"])


 def test_map_node__use_parallelism():
@@ -62,4 +62,4 @@ def test_map_node__use_parallelism():

     # THEN the node should have ran in parallel
     run_time = (end_ts - start_ts) / 10**9
-    assert run_time < 0.
+    assert run_time < 0.2
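MapNode.run() now returns Iterator[BaseOutput] rather than a single Outputs object, and honors the new max_concurrency cap on subworkflow threads. A short sketch of the new consumption pattern, mirroring the updated test above; construction of the node itself is elided:

# `node` is any configured MapNode instance.
# Each iteration yields an "INITIATED" delta, then per-output "FULFILLED" deltas,
# and the last BaseOutput per output name carries the full aggregated list.
outputs = list(node.run())
final = outputs[-1]
print(final.name, final.value)  # e.g. "value", ["foo bar 1", "foo bar 2", "foo bar 3"]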
vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py
@@ -13,7 +13,7 @@ from vellum.workflows.types.generics import StateType

 class BasePromptNode(BaseNode, Generic[StateType]):
     # Inputs that are passed to the Prompt
-    prompt_inputs: ClassVar[EntityInputsInterface]
+    prompt_inputs: ClassVar[Optional[EntityInputsInterface]] = None

     request_options: Optional[RequestOptions] = None

vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py
@@ -53,8 +53,7 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):

     def _get_prompt_event_stream(self) -> Iterator[AdHocExecutePromptEvent]:
         input_variables, input_values = self._compile_prompt_inputs()
-
-        parent_context = current_parent_context.model_dump_json() if current_parent_context else None
+        parent_context = get_parent_context()
         request_options = self.request_options or RequestOptions()
         request_options["additional_body_parameters"] = {
             "execution_context": {"parent_context": parent_context},
@@ -77,13 +76,16 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
             blocks=self.blocks,
             functions=normalized_functions,
             expand_meta=self.expand_meta,
-            request_options=
+            request_options=request_options,
         )

     def _compile_prompt_inputs(self) -> Tuple[List[VellumVariable], List[PromptRequestInput]]:
         input_variables: List[VellumVariable] = []
         input_values: List[PromptRequestInput] = []

+        if not self.prompt_inputs:
+            return input_variables, input_values
+
         for input_name, input_value in self.prompt_inputs.items():
             if isinstance(input_value, str):
                 input_variables.append(
vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py
@@ -74,6 +74,9 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):

         compiled_inputs: List[PromptDeploymentInputRequest] = []

+        if not self.prompt_inputs:
+            return compiled_inputs
+
         for input_name, input_value in self.prompt_inputs.items():
             if isinstance(input_value, str):
                 compiled_inputs.append(
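With prompt_inputs now typed as Optional and defaulting to None across the base prompt nodes, a prompt node whose prompt takes no input variables can omit the attribute entirely and the input-compilation helpers short-circuit to empty inputs. A rough sketch, assuming the displayable PromptDeploymentNode and its deployment attribute keep the shape they had in earlier releases (neither appears in these hunks):

from vellum.workflows.nodes.displayable import PromptDeploymentNode  # assumed import path


class GreetUser(PromptDeploymentNode):
    deployment = "my-prompt-deployment"  # illustrative deployment name
    # prompt_inputs intentionally omitted: as of 0.13.1 it defaults to None,
    # so the compiled input list is simply empty.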
vellum/workflows/nodes/displayable/bases/search_node.py
@@ -15,6 +15,7 @@ from vellum.core import ApiError, RequestOptions
 from vellum.workflows.errors import WorkflowErrorCode
 from vellum.workflows.exceptions import NodeException
 from vellum.workflows.nodes.bases import BaseNode
+from vellum.workflows.nodes.displayable.bases.types import SearchFilters
 from vellum.workflows.outputs import BaseOutputs
 from vellum.workflows.types.generics import StateType

@@ -33,7 +34,11 @@ class BaseSearchNode(BaseNode[StateType], Generic[StateType]):
     document_index: Union[UUID, str] - Either the UUID or name of the Vellum Document Index that you'd like to search
         against
     query: str - The query to search for
-
+    limit: Optional[int] = None - The maximum number of results to return.
+    weights: Optional[SearchWeightsRequest] = None - The weights to use for the search. Must add up to 1.0.
+    result_merging: Optional[SearchResultMergingRequest] = None - The configuration for merging results.
+    filters: Optional[SearchFiltersRequest] = None - The filters to apply to the search.
+    options: Optional[SearchRequestOptionsRequest] = None - [DEPRECATED] Runtime configuration for the search
     request_options: Optional[RequestOptions] = None - The request options to use for the search
     """

@@ -43,11 +48,24 @@ class BaseSearchNode(BaseNode[StateType], Generic[StateType]):
     # The Document Index to Search against. Identified by either its UUID or its name.
     document_index: ClassVar[Union[UUID, str]]

+    # The maximum number of results to return.
+    limit: ClassVar[Optional[int]] = None
+
+    # The weights to use for the search. Must add up to 1.0.
+    weights: ClassVar[Optional[SearchWeightsRequest]] = None
+
+    # The configuration for merging results.
+    result_merging: ClassVar[Optional[SearchResultMergingRequest]] = None
+
+    # The filters to apply to the search.
+    filters: ClassVar[Optional[SearchFilters]] = None
+
     # Ideally we could reuse node descriptors to derive other node descriptor values. Two action items are
     # blocking us from doing so in this use case:
     # 1. Node Descriptor resolution during runtime - https://app.shortcut.com/vellum/story/4781
     # 2. Math operations between descriptors - https://app.shortcut.com/vellum/story/4782
     # search_weights = DEFAULT_SEARCH_WEIGHTS
+    # Deprecated: Use the top level `limit`, `weights`, `result_merging`, and `filters` attributes instead
     options = SearchRequestOptionsRequest(
         limit=DEFAULT_SEARCH_LIMIT,
         weights=SearchWeightsRequest(
@@ -77,7 +95,7 @@ class BaseSearchNode(BaseNode[StateType], Generic[StateType]):
             return self._context.vellum_client.search(
                 query=self.query,
                 document_index=str(self.document_index),
-                options=self.
+                options=self._get_options_request(),
             )
         except NotFoundError:
             raise NodeException(
@@ -90,6 +108,23 @@ class BaseSearchNode(BaseNode[StateType], Generic[StateType]):
                 code=WorkflowErrorCode.INTERNAL_ERROR,
             )

+    def _get_options_request(self) -> SearchRequestOptionsRequest:
+        return SearchRequestOptionsRequest(
+            limit=self.limit if self.limit is not None else self.options.limit,
+            weights=self.weights if self.weights is not None else self.options.weights,
+            result_merging=self.result_merging if self.result_merging is not None else self.options.result_merging,
+            filters=self._get_filters_request(),
+        )
+
+    def _get_filters_request(self) -> Optional[SearchFiltersRequest]:
+        if self.filters is None:
+            return self.options.filters
+
+        return SearchFiltersRequest(
+            external_ids=self.filters.external_ids,
+            metadata=self.filters.metadata.to_request() if self.filters.metadata is not None else None,
+        )
+
     def run(self) -> Outputs:
         response = self._perform_search()
         return self.Outputs(results=response.results)
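These hunks promote limit, weights, result_merging, and filters to first-class class attributes on BaseSearchNode, keeping the old options attribute as a deprecated fallback that _get_options_request() merges in. A rough sketch of a subclass using the new attributes; the document index name is illustrative and the SearchWeightsRequest field names are assumptions (they do not appear in this diff):

from vellum.client.types.search_weights_request import SearchWeightsRequest
from vellum.workflows.nodes.displayable.bases.search_node import BaseSearchNode
from vellum.workflows.nodes.displayable.bases.types import SearchFilters


class MyDocsSearch(BaseSearchNode):
    document_index = "my-document-index"  # illustrative name
    query = "how do I reset my password?"
    limit = 5                             # overrides options.limit
    weights = SearchWeightsRequest(semantic_similarity=0.8, keywords=0.2)  # field names assumed
    filters = SearchFilters(external_ids=["kb-article-42"])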
vellum/workflows/nodes/displayable/bases/tests/__init__.py
File without changes
vellum/workflows/nodes/displayable/bases/tests/test_utils.py
@@ -0,0 +1,61 @@
+import pytest
+import enum
+
+from vellum.client.types.chat_history_vellum_value import ChatHistoryVellumValue
+from vellum.client.types.chat_message import ChatMessage
+from vellum.client.types.json_vellum_value import JsonVellumValue
+from vellum.client.types.number_vellum_value import NumberVellumValue
+from vellum.client.types.search_result import SearchResult
+from vellum.client.types.search_result_document import SearchResultDocument
+from vellum.client.types.search_results_vellum_value import SearchResultsVellumValue
+from vellum.client.types.string_vellum_value import StringVellumValue
+from vellum.client.types.string_vellum_value_request import StringVellumValueRequest
+from vellum.workflows.nodes.displayable.bases.utils import primitive_to_vellum_value, primitive_to_vellum_value_request
+
+
+class MockEnum(enum.Enum):
+    FOO = "foo"
+
+
+@pytest.mark.parametrize(
+    ["value", "expected_output"],
+    [
+        ("hello", StringVellumValue(value="hello")),
+        (MockEnum.FOO, StringVellumValue(value="foo")),
+        (1, NumberVellumValue(value=1)),
+        (1.0, NumberVellumValue(value=1.0)),
+        (
+            [ChatMessage(role="USER", text="hello")],
+            ChatHistoryVellumValue(value=[ChatMessage(role="USER", text="hello")]),
+        ),
+        (
+            [
+                SearchResult(
+                    text="Search query",
+                    score="0.0",
+                    keywords=["keywords"],
+                    document=SearchResultDocument(label="label"),
+                )
+            ],
+            SearchResultsVellumValue(
+                value=[
+                    SearchResult(
+                        text="Search query",
+                        score="0.0",
+                        keywords=["keywords"],
+                        document=SearchResultDocument(label="label"),
+                    )
+                ]
+            ),
+        ),
+        (StringVellumValue(value="hello"), StringVellumValue(value="hello")),
+        (StringVellumValueRequest(value="hello"), StringVellumValueRequest(value="hello")),
+        ({"foo": "bar"}, JsonVellumValue(value={"foo": "bar"})),
+    ],
+)
+def test_primitive_to_vellum_value(value, expected_output):
+    assert primitive_to_vellum_value(value) == expected_output
+
+
+def test_primitive_to_vellum_value_request():
+    assert primitive_to_vellum_value_request("hello") == StringVellumValueRequest(value="hello")
vellum/workflows/nodes/displayable/bases/types.py
@@ -0,0 +1,42 @@
+from typing import Any, List, Optional, Union
+
+from vellum.client.core.pydantic_utilities import UniversalBaseModel
+from vellum.client.types.condition_combinator import ConditionCombinator
+from vellum.client.types.logical_operator import LogicalOperator
+from vellum.client.types.vellum_value_logical_condition_group_request import VellumValueLogicalConditionGroupRequest
+from vellum.client.types.vellum_value_logical_condition_request import VellumValueLogicalConditionRequest
+from vellum.workflows.nodes.displayable.bases.utils import primitive_to_vellum_value_request
+
+
+class MetadataLogicalConditionGroup(UniversalBaseModel):
+    combinator: ConditionCombinator
+    negated: bool
+    conditions: List["MetadataLogicalExpression"]
+
+    def to_request(self) -> VellumValueLogicalConditionGroupRequest:
+        return VellumValueLogicalConditionGroupRequest(
+            combinator=self.combinator,
+            negated=self.negated,
+            conditions=[c.to_request() for c in self.conditions],
+        )
+
+
+class MetadataLogicalCondition(UniversalBaseModel):
+    lhs_variable: Any
+    operator: LogicalOperator
+    rhs_variable: Any
+
+    def to_request(self) -> VellumValueLogicalConditionRequest:
+        return VellumValueLogicalConditionRequest(
+            lhs_variable=primitive_to_vellum_value_request(self.lhs_variable),
+            operator=self.operator,
+            rhs_variable=primitive_to_vellum_value_request(self.rhs_variable),
+        )
+
+
+MetadataLogicalExpression = Union[MetadataLogicalConditionGroup, MetadataLogicalCondition]
+
+
+class SearchFilters(UniversalBaseModel):
+    external_ids: Optional[List[str]] = None
+    metadata: Optional[MetadataLogicalConditionGroup] = None