vellum-ai 0.13.0__py3-none-any.whl → 0.13.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/core/pydantic_utilities.py +5 -0
- vellum/client/resources/workflows/client.py +8 -0
- vellum/client/types/logical_operator.py +2 -0
- vellum/workflows/descriptors/base.py +1 -1
- vellum/workflows/descriptors/tests/test_utils.py +3 -0
- vellum/workflows/expressions/accessor.py +8 -2
- vellum/workflows/nodes/core/map_node/node.py +49 -24
- vellum/workflows/nodes/core/map_node/tests/test_node.py +4 -4
- vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +1 -1
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +5 -3
- vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +3 -0
- vellum/workflows/nodes/displayable/bases/search_node.py +37 -2
- vellum/workflows/nodes/displayable/bases/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/bases/tests/test_utils.py +61 -0
- vellum/workflows/nodes/displayable/bases/types.py +42 -0
- vellum/workflows/nodes/displayable/bases/utils.py +112 -0
- vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +0 -1
- vellum/workflows/nodes/displayable/search_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/search_node/tests/test_node.py +164 -0
- vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py +2 -3
- vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +0 -1
- vellum/workflows/runner/runner.py +37 -4
- vellum/workflows/types/tests/test_utils.py +5 -2
- vellum/workflows/types/utils.py +4 -0
- vellum/workflows/workflows/base.py +14 -0
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.2.dist-info}/METADATA +1 -1
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.2.dist-info}/RECORD +53 -42
- vellum_cli/__init__.py +24 -0
- vellum_cli/ping.py +28 -0
- vellum_cli/push.py +62 -12
- vellum_cli/tests/test_ping.py +47 -0
- vellum_cli/tests/test_push.py +76 -0
- vellum_ee/workflows/display/nodes/vellum/base_node.py +59 -11
- vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +3 -0
- vellum_ee/workflows/display/nodes/vellum/map_node.py +1 -1
- vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +14 -10
- vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/utils.py +8 -1
- vellum_ee/workflows/display/tests/test_vellum_workflow_display.py +48 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +67 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +286 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_outputs_serialization.py +177 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_ports_serialization.py +666 -14
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_trigger_serialization.py +7 -8
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +35 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +29 -2
- vellum_ee/workflows/display/utils/vellum.py +4 -42
- vellum_ee/workflows/display/vellum.py +7 -36
- vellum_ee/workflows/display/workflows/vellum_workflow_display.py +5 -2
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.2.dist-info}/LICENSE +0 -0
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.2.dist-info}/WHEEL +0 -0
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.2.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
|
|
18
18
|
headers: typing.Dict[str, str] = {
|
19
19
|
"X-Fern-Language": "Python",
|
20
20
|
"X-Fern-SDK-Name": "vellum-ai",
|
21
|
-
"X-Fern-SDK-Version": "0.13.
|
21
|
+
"X-Fern-SDK-Version": "0.13.2",
|
22
22
|
}
|
23
23
|
headers["X_API_KEY"] = self.api_key
|
24
24
|
return headers
|
@@ -133,6 +133,11 @@ class UniversalBaseModel(pydantic.BaseModel):
|
|
133
133
|
#
|
134
134
|
# We'd ideally do the same for Pydantic V2, but it shells out to a library to serialize models
|
135
135
|
# that we have less control over, and this is less intrusive than custom serializers for now.
|
136
|
+
kwargs = {
|
137
|
+
**kwargs,
|
138
|
+
"warnings": False,
|
139
|
+
}
|
140
|
+
|
136
141
|
if IS_PYDANTIC_V2:
|
137
142
|
kwargs_with_defaults_exclude_unset: typing.Any = {
|
138
143
|
**kwargs,
|
@@ -101,6 +101,7 @@ class WorkflowsClient:
|
|
101
101
|
deployment_config: typing.Optional[WorkflowPushDeploymentConfigRequest] = OMIT,
|
102
102
|
artifact: typing.Optional[core.File] = OMIT,
|
103
103
|
dry_run: typing.Optional[bool] = OMIT,
|
104
|
+
strict: typing.Optional[bool] = OMIT,
|
104
105
|
request_options: typing.Optional[RequestOptions] = None,
|
105
106
|
) -> WorkflowPushResponse:
|
106
107
|
"""
|
@@ -120,6 +121,8 @@ class WorkflowsClient:
|
|
120
121
|
|
121
122
|
dry_run : typing.Optional[bool]
|
122
123
|
|
124
|
+
strict : typing.Optional[bool]
|
125
|
+
|
123
126
|
request_options : typing.Optional[RequestOptions]
|
124
127
|
Request-specific configuration.
|
125
128
|
|
@@ -150,6 +153,7 @@ class WorkflowsClient:
|
|
150
153
|
"workflow_sandbox_id": workflow_sandbox_id,
|
151
154
|
"deployment_config": deployment_config,
|
152
155
|
"dry_run": dry_run,
|
156
|
+
"strict": strict,
|
153
157
|
},
|
154
158
|
files={
|
155
159
|
"artifact": artifact,
|
@@ -254,6 +258,7 @@ class AsyncWorkflowsClient:
|
|
254
258
|
deployment_config: typing.Optional[WorkflowPushDeploymentConfigRequest] = OMIT,
|
255
259
|
artifact: typing.Optional[core.File] = OMIT,
|
256
260
|
dry_run: typing.Optional[bool] = OMIT,
|
261
|
+
strict: typing.Optional[bool] = OMIT,
|
257
262
|
request_options: typing.Optional[RequestOptions] = None,
|
258
263
|
) -> WorkflowPushResponse:
|
259
264
|
"""
|
@@ -273,6 +278,8 @@ class AsyncWorkflowsClient:
|
|
273
278
|
|
274
279
|
dry_run : typing.Optional[bool]
|
275
280
|
|
281
|
+
strict : typing.Optional[bool]
|
282
|
+
|
276
283
|
request_options : typing.Optional[RequestOptions]
|
277
284
|
Request-specific configuration.
|
278
285
|
|
@@ -311,6 +318,7 @@ class AsyncWorkflowsClient:
|
|
311
318
|
"workflow_sandbox_id": workflow_sandbox_id,
|
312
319
|
"deployment_config": deployment_config,
|
313
320
|
"dry_run": dry_run,
|
321
|
+
"strict": strict,
|
314
322
|
},
|
315
323
|
files={
|
316
324
|
"artifact": artifact,
|
@@ -121,7 +121,7 @@ class BaseDescriptor(Generic[_T]):
|
|
121
121
|
|
122
122
|
return CoalesceExpression(lhs=self, rhs=other)
|
123
123
|
|
124
|
-
def __getitem__(self, field: str) -> "AccessorExpression":
|
124
|
+
def __getitem__(self, field: Union[str, int]) -> "AccessorExpression":
|
125
125
|
from vellum.workflows.expressions.accessor import AccessorExpression
|
126
126
|
|
127
127
|
return AccessorExpression(base=self, field=field)
|
@@ -19,6 +19,7 @@ class FixtureState(BaseState):
|
|
19
19
|
}
|
20
20
|
|
21
21
|
eta = None
|
22
|
+
theta = ["baz"]
|
22
23
|
|
23
24
|
|
24
25
|
class DummyNode(BaseNode[FixtureState]):
|
@@ -75,6 +76,7 @@ class DummyNode(BaseNode[FixtureState]):
|
|
75
76
|
),
|
76
77
|
(FixtureState.zeta["foo"], "bar"),
|
77
78
|
(ConstantValueReference(1), 1),
|
79
|
+
(FixtureState.theta[0], "baz"),
|
78
80
|
],
|
79
81
|
ids=[
|
80
82
|
"or",
|
@@ -119,6 +121,7 @@ class DummyNode(BaseNode[FixtureState]):
|
|
119
121
|
"or_and",
|
120
122
|
"accessor",
|
121
123
|
"constants",
|
124
|
+
"list_index",
|
122
125
|
],
|
123
126
|
)
|
124
127
|
def test_resolve_value__happy_path(descriptor, expected_value):
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from collections.abc import Mapping
|
2
2
|
import dataclasses
|
3
|
-
from typing import Any, Sequence, Type, TypeVar
|
3
|
+
from typing import Any, Sequence, Type, TypeVar, Union
|
4
4
|
|
5
5
|
from pydantic import BaseModel, GetCoreSchemaHandler
|
6
6
|
from pydantic_core import core_schema
|
@@ -17,7 +17,7 @@ class AccessorExpression(BaseDescriptor[Any]):
|
|
17
17
|
self,
|
18
18
|
*,
|
19
19
|
base: BaseDescriptor[LHS],
|
20
|
-
field: str,
|
20
|
+
field: Union[str, int],
|
21
21
|
) -> None:
|
22
22
|
super().__init__(
|
23
23
|
name=f"{base.name}.{field}",
|
@@ -31,9 +31,15 @@ class AccessorExpression(BaseDescriptor[Any]):
|
|
31
31
|
base = resolve_value(self._base, state)
|
32
32
|
|
33
33
|
if dataclasses.is_dataclass(base):
|
34
|
+
if isinstance(self._field, int):
|
35
|
+
raise ValueError("Cannot access field by index on a dataclass")
|
36
|
+
|
34
37
|
return getattr(base, self._field)
|
35
38
|
|
36
39
|
if isinstance(base, BaseModel):
|
40
|
+
if isinstance(self._field, int):
|
41
|
+
raise ValueError("Cannot access field by index on a BaseModel")
|
42
|
+
|
37
43
|
return getattr(base, self._field)
|
38
44
|
|
39
45
|
if isinstance(base, Mapping):
|
@@ -1,7 +1,20 @@
|
|
1
1
|
from collections import defaultdict
|
2
2
|
from queue import Empty, Queue
|
3
3
|
from threading import Thread
|
4
|
-
from typing import
|
4
|
+
from typing import (
|
5
|
+
TYPE_CHECKING,
|
6
|
+
Callable,
|
7
|
+
Dict,
|
8
|
+
Generic,
|
9
|
+
Iterator,
|
10
|
+
List,
|
11
|
+
Optional,
|
12
|
+
Tuple,
|
13
|
+
Type,
|
14
|
+
TypeVar,
|
15
|
+
Union,
|
16
|
+
overload,
|
17
|
+
)
|
5
18
|
|
6
19
|
from vellum.workflows.context import execution_context, get_parent_context
|
7
20
|
from vellum.workflows.descriptors.base import BaseDescriptor
|
@@ -12,6 +25,7 @@ from vellum.workflows.inputs.base import BaseInputs
|
|
12
25
|
from vellum.workflows.nodes.bases.base_adornment_node import BaseAdornmentNode
|
13
26
|
from vellum.workflows.nodes.utils import create_adornment
|
14
27
|
from vellum.workflows.outputs import BaseOutputs
|
28
|
+
from vellum.workflows.outputs.base import BaseOutput
|
15
29
|
from vellum.workflows.references.output import OutputReference
|
16
30
|
from vellum.workflows.state.context import WorkflowContext
|
17
31
|
from vellum.workflows.types.generics import StateType
|
@@ -29,11 +43,11 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
|
|
29
43
|
|
30
44
|
items: List[MapNodeItemType] - The items to map over
|
31
45
|
subworkflow: Type["BaseWorkflow[SubworkflowInputs, BaseState]"] - The Subworkflow to execute on each iteration
|
32
|
-
|
46
|
+
max_concurrency: Optional[int] = None - The maximum number of concurrent subworkflow executions
|
33
47
|
"""
|
34
48
|
|
35
49
|
items: List[MapNodeItemType]
|
36
|
-
|
50
|
+
max_concurrency: Optional[int] = None
|
37
51
|
|
38
52
|
class Outputs(BaseAdornmentNode.Outputs):
|
39
53
|
pass
|
@@ -47,7 +61,7 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
|
|
47
61
|
index: int
|
48
62
|
all_items: List[MapNodeItemType] # type: ignore[valid-type]
|
49
63
|
|
50
|
-
def run(self) ->
|
64
|
+
def run(self) -> Iterator[BaseOutput]:
|
51
65
|
mapped_items: Dict[str, List] = defaultdict(list)
|
52
66
|
for output_descripter in self.subworkflow.Outputs:
|
53
67
|
mapped_items[output_descripter.name] = [None] * len(self.items)
|
@@ -66,14 +80,14 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
|
|
66
80
|
"parent_context": parent_context,
|
67
81
|
},
|
68
82
|
)
|
69
|
-
if self.
|
83
|
+
if self.max_concurrency is None:
|
70
84
|
thread.start()
|
71
85
|
else:
|
72
86
|
self._concurrency_queue.put(thread)
|
73
87
|
|
74
|
-
if self.
|
88
|
+
if self.max_concurrency is not None:
|
75
89
|
concurrency_count = 0
|
76
|
-
while concurrency_count < self.
|
90
|
+
while concurrency_count < self.max_concurrency:
|
77
91
|
is_empty = self._start_thread()
|
78
92
|
if is_empty:
|
79
93
|
break
|
@@ -83,40 +97,45 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
|
|
83
97
|
try:
|
84
98
|
while map_node_event := self._event_queue.get():
|
85
99
|
index = map_node_event[0]
|
86
|
-
|
87
|
-
self._context._emit_subworkflow_event(
|
100
|
+
subworkflow_event = map_node_event[1]
|
101
|
+
self._context._emit_subworkflow_event(subworkflow_event)
|
88
102
|
|
89
|
-
if
|
90
|
-
|
103
|
+
if subworkflow_event.name == "workflow.execution.initiated":
|
104
|
+
for output_name in mapped_items.keys():
|
105
|
+
yield BaseOutput(name=output_name, delta=(None, index, "INITIATED"))
|
106
|
+
|
107
|
+
elif subworkflow_event.name == "workflow.execution.fulfilled":
|
108
|
+
workflow_output_vars = vars(subworkflow_event.outputs)
|
91
109
|
|
92
110
|
for output_name in workflow_output_vars:
|
93
111
|
output_mapped_items = mapped_items[output_name]
|
94
112
|
output_mapped_items[index] = workflow_output_vars[output_name]
|
113
|
+
yield BaseOutput(
|
114
|
+
name=output_name,
|
115
|
+
delta=(output_mapped_items[index], index, "FULFILLED"),
|
116
|
+
)
|
95
117
|
|
96
118
|
fulfilled_iterations[index] = True
|
97
119
|
if all(fulfilled_iterations):
|
98
120
|
break
|
99
121
|
|
100
|
-
if self.
|
122
|
+
if self.max_concurrency is not None:
|
101
123
|
self._start_thread()
|
102
|
-
elif
|
124
|
+
elif subworkflow_event.name == "workflow.execution.paused":
|
103
125
|
raise NodeException(
|
104
126
|
code=WorkflowErrorCode.INVALID_OUTPUTS,
|
105
127
|
message=f"Subworkflow unexpectedly paused on iteration {index}",
|
106
128
|
)
|
107
|
-
elif
|
129
|
+
elif subworkflow_event.name == "workflow.execution.rejected":
|
108
130
|
raise NodeException(
|
109
|
-
f"Subworkflow failed on iteration {index} with error: {
|
110
|
-
code=
|
131
|
+
f"Subworkflow failed on iteration {index} with error: {subworkflow_event.error.message}",
|
132
|
+
code=subworkflow_event.error.code,
|
111
133
|
)
|
112
134
|
except Empty:
|
113
135
|
pass
|
114
136
|
|
115
|
-
outputs = self.Outputs()
|
116
137
|
for output_name, output_list in mapped_items.items():
|
117
|
-
|
118
|
-
|
119
|
-
return outputs
|
138
|
+
yield BaseOutput(name=output_name, value=output_list)
|
120
139
|
|
121
140
|
def _context_run_subworkflow(
|
122
141
|
self, *, item: MapNodeItemType, index: int, parent_context: Optional[ParentContext] = None
|
@@ -149,21 +168,27 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
|
|
149
168
|
|
150
169
|
@overload
|
151
170
|
@classmethod
|
152
|
-
def wrap(
|
171
|
+
def wrap(
|
172
|
+
cls, items: List[MapNodeItemType], max_concurrency: Optional[int] = None
|
173
|
+
) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]: ...
|
153
174
|
|
154
175
|
# TODO: We should be able to do this overload automatically as we do with node attributes
|
155
176
|
# https://app.shortcut.com/vellum/story/5289
|
156
177
|
@overload
|
157
178
|
@classmethod
|
158
179
|
def wrap(
|
159
|
-
cls,
|
180
|
+
cls,
|
181
|
+
items: BaseDescriptor[List[MapNodeItemType]],
|
182
|
+
max_concurrency: Optional[int] = None,
|
160
183
|
) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]: ...
|
161
184
|
|
162
185
|
@classmethod
|
163
186
|
def wrap(
|
164
|
-
cls,
|
187
|
+
cls,
|
188
|
+
items: Union[List[MapNodeItemType], BaseDescriptor[List[MapNodeItemType]]],
|
189
|
+
max_concurrency: Optional[int] = None,
|
165
190
|
) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]:
|
166
|
-
return create_adornment(cls, attributes={"items": items})
|
191
|
+
return create_adornment(cls, attributes={"items": items, "max_concurrency": max_concurrency})
|
167
192
|
|
168
193
|
@classmethod
|
169
194
|
def __annotate_outputs_class__(cls, outputs_class: Type[BaseOutputs], reference: OutputReference) -> None:
|
@@ -3,7 +3,7 @@ import time
|
|
3
3
|
from vellum.workflows.inputs.base import BaseInputs
|
4
4
|
from vellum.workflows.nodes.bases import BaseNode
|
5
5
|
from vellum.workflows.nodes.core.map_node.node import MapNode
|
6
|
-
from vellum.workflows.outputs.base import BaseOutputs
|
6
|
+
from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
|
7
7
|
from vellum.workflows.state.base import BaseState, StateMeta
|
8
8
|
|
9
9
|
|
@@ -35,10 +35,10 @@ def test_map_node__use_parent_inputs_and_state():
|
|
35
35
|
meta=StateMeta(workflow_inputs=Inputs(foo="foo")),
|
36
36
|
)
|
37
37
|
)
|
38
|
-
outputs = node.run()
|
38
|
+
outputs = list(node.run())
|
39
39
|
|
40
40
|
# THEN the data is used successfully
|
41
|
-
assert outputs
|
41
|
+
assert outputs[-1] == BaseOutput(name="value", value=["foo bar 1", "foo bar 2", "foo bar 3"])
|
42
42
|
|
43
43
|
|
44
44
|
def test_map_node__use_parallelism():
|
@@ -62,4 +62,4 @@ def test_map_node__use_parallelism():
|
|
62
62
|
|
63
63
|
# THEN the node should have ran in parallel
|
64
64
|
run_time = (end_ts - start_ts) / 10**9
|
65
|
-
assert run_time < 0.
|
65
|
+
assert run_time < 0.2
|
@@ -13,7 +13,7 @@ from vellum.workflows.types.generics import StateType
|
|
13
13
|
|
14
14
|
class BasePromptNode(BaseNode, Generic[StateType]):
|
15
15
|
# Inputs that are passed to the Prompt
|
16
|
-
prompt_inputs: ClassVar[EntityInputsInterface]
|
16
|
+
prompt_inputs: ClassVar[Optional[EntityInputsInterface]] = None
|
17
17
|
|
18
18
|
request_options: Optional[RequestOptions] = None
|
19
19
|
|
@@ -53,8 +53,7 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
|
|
53
53
|
|
54
54
|
def _get_prompt_event_stream(self) -> Iterator[AdHocExecutePromptEvent]:
|
55
55
|
input_variables, input_values = self._compile_prompt_inputs()
|
56
|
-
|
57
|
-
parent_context = current_parent_context.model_dump_json() if current_parent_context else None
|
56
|
+
parent_context = get_parent_context()
|
58
57
|
request_options = self.request_options or RequestOptions()
|
59
58
|
request_options["additional_body_parameters"] = {
|
60
59
|
"execution_context": {"parent_context": parent_context},
|
@@ -77,13 +76,16 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
|
|
77
76
|
blocks=self.blocks,
|
78
77
|
functions=normalized_functions,
|
79
78
|
expand_meta=self.expand_meta,
|
80
|
-
request_options=
|
79
|
+
request_options=request_options,
|
81
80
|
)
|
82
81
|
|
83
82
|
def _compile_prompt_inputs(self) -> Tuple[List[VellumVariable], List[PromptRequestInput]]:
|
84
83
|
input_variables: List[VellumVariable] = []
|
85
84
|
input_values: List[PromptRequestInput] = []
|
86
85
|
|
86
|
+
if not self.prompt_inputs:
|
87
|
+
return input_variables, input_values
|
88
|
+
|
87
89
|
for input_name, input_value in self.prompt_inputs.items():
|
88
90
|
if isinstance(input_value, str):
|
89
91
|
input_variables.append(
|
@@ -74,6 +74,9 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
|
|
74
74
|
|
75
75
|
compiled_inputs: List[PromptDeploymentInputRequest] = []
|
76
76
|
|
77
|
+
if not self.prompt_inputs:
|
78
|
+
return compiled_inputs
|
79
|
+
|
77
80
|
for input_name, input_value in self.prompt_inputs.items():
|
78
81
|
if isinstance(input_value, str):
|
79
82
|
compiled_inputs.append(
|
@@ -15,6 +15,7 @@ from vellum.core import ApiError, RequestOptions
|
|
15
15
|
from vellum.workflows.errors import WorkflowErrorCode
|
16
16
|
from vellum.workflows.exceptions import NodeException
|
17
17
|
from vellum.workflows.nodes.bases import BaseNode
|
18
|
+
from vellum.workflows.nodes.displayable.bases.types import SearchFilters
|
18
19
|
from vellum.workflows.outputs import BaseOutputs
|
19
20
|
from vellum.workflows.types.generics import StateType
|
20
21
|
|
@@ -33,7 +34,11 @@ class BaseSearchNode(BaseNode[StateType], Generic[StateType]):
|
|
33
34
|
document_index: Union[UUID, str] - Either the UUID or name of the Vellum Document Index that you'd like to search
|
34
35
|
against
|
35
36
|
query: str - The query to search for
|
36
|
-
|
37
|
+
limit: Optional[int] = None - The maximum number of results to return.
|
38
|
+
weights: Optional[SearchWeightsRequest] = None - The weights to use for the search. Must add up to 1.0.
|
39
|
+
result_merging: Optional[SearchResultMergingRequest] = None - The configuration for merging results.
|
40
|
+
filters: Optional[SearchFiltersRequest] = None - The filters to apply to the search.
|
41
|
+
options: Optional[SearchRequestOptionsRequest] = None - [DEPRECATED] Runtime configuration for the search
|
37
42
|
request_options: Optional[RequestOptions] = None - The request options to use for the search
|
38
43
|
"""
|
39
44
|
|
@@ -43,11 +48,24 @@ class BaseSearchNode(BaseNode[StateType], Generic[StateType]):
|
|
43
48
|
# The Document Index to Search against. Identified by either its UUID or its name.
|
44
49
|
document_index: ClassVar[Union[UUID, str]]
|
45
50
|
|
51
|
+
# The maximum number of results to return.
|
52
|
+
limit: ClassVar[Optional[int]] = None
|
53
|
+
|
54
|
+
# The weights to use for the search. Must add up to 1.0.
|
55
|
+
weights: ClassVar[Optional[SearchWeightsRequest]] = None
|
56
|
+
|
57
|
+
# The configuration for merging results.
|
58
|
+
result_merging: ClassVar[Optional[SearchResultMergingRequest]] = None
|
59
|
+
|
60
|
+
# The filters to apply to the search.
|
61
|
+
filters: ClassVar[Optional[SearchFilters]] = None
|
62
|
+
|
46
63
|
# Ideally we could reuse node descriptors to derive other node descriptor values. Two action items are
|
47
64
|
# blocking us from doing so in this use case:
|
48
65
|
# 1. Node Descriptor resolution during runtime - https://app.shortcut.com/vellum/story/4781
|
49
66
|
# 2. Math operations between descriptors - https://app.shortcut.com/vellum/story/4782
|
50
67
|
# search_weights = DEFAULT_SEARCH_WEIGHTS
|
68
|
+
# Deprecated: Use the top level `limit`, `weights`, `result_merging`, and `filters` attributes instead
|
51
69
|
options = SearchRequestOptionsRequest(
|
52
70
|
limit=DEFAULT_SEARCH_LIMIT,
|
53
71
|
weights=SearchWeightsRequest(
|
@@ -77,7 +95,7 @@ class BaseSearchNode(BaseNode[StateType], Generic[StateType]):
|
|
77
95
|
return self._context.vellum_client.search(
|
78
96
|
query=self.query,
|
79
97
|
document_index=str(self.document_index),
|
80
|
-
options=self.
|
98
|
+
options=self._get_options_request(),
|
81
99
|
)
|
82
100
|
except NotFoundError:
|
83
101
|
raise NodeException(
|
@@ -90,6 +108,23 @@ class BaseSearchNode(BaseNode[StateType], Generic[StateType]):
|
|
90
108
|
code=WorkflowErrorCode.INTERNAL_ERROR,
|
91
109
|
)
|
92
110
|
|
111
|
+
def _get_options_request(self) -> SearchRequestOptionsRequest:
|
112
|
+
return SearchRequestOptionsRequest(
|
113
|
+
limit=self.limit if self.limit is not None else self.options.limit,
|
114
|
+
weights=self.weights if self.weights is not None else self.options.weights,
|
115
|
+
result_merging=self.result_merging if self.result_merging is not None else self.options.result_merging,
|
116
|
+
filters=self._get_filters_request(),
|
117
|
+
)
|
118
|
+
|
119
|
+
def _get_filters_request(self) -> Optional[SearchFiltersRequest]:
|
120
|
+
if self.filters is None:
|
121
|
+
return self.options.filters
|
122
|
+
|
123
|
+
return SearchFiltersRequest(
|
124
|
+
external_ids=self.filters.external_ids,
|
125
|
+
metadata=self.filters.metadata.to_request() if self.filters.metadata is not None else None,
|
126
|
+
)
|
127
|
+
|
93
128
|
def run(self) -> Outputs:
|
94
129
|
response = self._perform_search()
|
95
130
|
return self.Outputs(results=response.results)
|
File without changes
|
@@ -0,0 +1,61 @@
|
|
1
|
+
import pytest
|
2
|
+
import enum
|
3
|
+
|
4
|
+
from vellum.client.types.chat_history_vellum_value import ChatHistoryVellumValue
|
5
|
+
from vellum.client.types.chat_message import ChatMessage
|
6
|
+
from vellum.client.types.json_vellum_value import JsonVellumValue
|
7
|
+
from vellum.client.types.number_vellum_value import NumberVellumValue
|
8
|
+
from vellum.client.types.search_result import SearchResult
|
9
|
+
from vellum.client.types.search_result_document import SearchResultDocument
|
10
|
+
from vellum.client.types.search_results_vellum_value import SearchResultsVellumValue
|
11
|
+
from vellum.client.types.string_vellum_value import StringVellumValue
|
12
|
+
from vellum.client.types.string_vellum_value_request import StringVellumValueRequest
|
13
|
+
from vellum.workflows.nodes.displayable.bases.utils import primitive_to_vellum_value, primitive_to_vellum_value_request
|
14
|
+
|
15
|
+
|
16
|
+
class MockEnum(enum.Enum):
|
17
|
+
FOO = "foo"
|
18
|
+
|
19
|
+
|
20
|
+
@pytest.mark.parametrize(
|
21
|
+
["value", "expected_output"],
|
22
|
+
[
|
23
|
+
("hello", StringVellumValue(value="hello")),
|
24
|
+
(MockEnum.FOO, StringVellumValue(value="foo")),
|
25
|
+
(1, NumberVellumValue(value=1)),
|
26
|
+
(1.0, NumberVellumValue(value=1.0)),
|
27
|
+
(
|
28
|
+
[ChatMessage(role="USER", text="hello")],
|
29
|
+
ChatHistoryVellumValue(value=[ChatMessage(role="USER", text="hello")]),
|
30
|
+
),
|
31
|
+
(
|
32
|
+
[
|
33
|
+
SearchResult(
|
34
|
+
text="Search query",
|
35
|
+
score="0.0",
|
36
|
+
keywords=["keywords"],
|
37
|
+
document=SearchResultDocument(label="label"),
|
38
|
+
)
|
39
|
+
],
|
40
|
+
SearchResultsVellumValue(
|
41
|
+
value=[
|
42
|
+
SearchResult(
|
43
|
+
text="Search query",
|
44
|
+
score="0.0",
|
45
|
+
keywords=["keywords"],
|
46
|
+
document=SearchResultDocument(label="label"),
|
47
|
+
)
|
48
|
+
]
|
49
|
+
),
|
50
|
+
),
|
51
|
+
(StringVellumValue(value="hello"), StringVellumValue(value="hello")),
|
52
|
+
(StringVellumValueRequest(value="hello"), StringVellumValueRequest(value="hello")),
|
53
|
+
({"foo": "bar"}, JsonVellumValue(value={"foo": "bar"})),
|
54
|
+
],
|
55
|
+
)
|
56
|
+
def test_primitive_to_vellum_value(value, expected_output):
|
57
|
+
assert primitive_to_vellum_value(value) == expected_output
|
58
|
+
|
59
|
+
|
60
|
+
def test_primitive_to_vellum_value_request():
|
61
|
+
assert primitive_to_vellum_value_request("hello") == StringVellumValueRequest(value="hello")
|
@@ -0,0 +1,42 @@
|
|
1
|
+
from typing import Any, List, Optional, Union
|
2
|
+
|
3
|
+
from vellum.client.core.pydantic_utilities import UniversalBaseModel
|
4
|
+
from vellum.client.types.condition_combinator import ConditionCombinator
|
5
|
+
from vellum.client.types.logical_operator import LogicalOperator
|
6
|
+
from vellum.client.types.vellum_value_logical_condition_group_request import VellumValueLogicalConditionGroupRequest
|
7
|
+
from vellum.client.types.vellum_value_logical_condition_request import VellumValueLogicalConditionRequest
|
8
|
+
from vellum.workflows.nodes.displayable.bases.utils import primitive_to_vellum_value_request
|
9
|
+
|
10
|
+
|
11
|
+
class MetadataLogicalConditionGroup(UniversalBaseModel):
|
12
|
+
combinator: ConditionCombinator
|
13
|
+
negated: bool
|
14
|
+
conditions: List["MetadataLogicalExpression"]
|
15
|
+
|
16
|
+
def to_request(self) -> VellumValueLogicalConditionGroupRequest:
|
17
|
+
return VellumValueLogicalConditionGroupRequest(
|
18
|
+
combinator=self.combinator,
|
19
|
+
negated=self.negated,
|
20
|
+
conditions=[c.to_request() for c in self.conditions],
|
21
|
+
)
|
22
|
+
|
23
|
+
|
24
|
+
class MetadataLogicalCondition(UniversalBaseModel):
|
25
|
+
lhs_variable: Any
|
26
|
+
operator: LogicalOperator
|
27
|
+
rhs_variable: Any
|
28
|
+
|
29
|
+
def to_request(self) -> VellumValueLogicalConditionRequest:
|
30
|
+
return VellumValueLogicalConditionRequest(
|
31
|
+
lhs_variable=primitive_to_vellum_value_request(self.lhs_variable),
|
32
|
+
operator=self.operator,
|
33
|
+
rhs_variable=primitive_to_vellum_value_request(self.rhs_variable),
|
34
|
+
)
|
35
|
+
|
36
|
+
|
37
|
+
MetadataLogicalExpression = Union[MetadataLogicalConditionGroup, MetadataLogicalCondition]
|
38
|
+
|
39
|
+
|
40
|
+
class SearchFilters(UniversalBaseModel):
|
41
|
+
external_ids: Optional[List[str]] = None
|
42
|
+
metadata: Optional[MetadataLogicalConditionGroup] = None
|