vellum-ai 0.13.0__py3-none-any.whl → 0.13.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (46)
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/client/core/pydantic_utilities.py +5 -0
  3. vellum/client/resources/workflows/client.py +8 -0
  4. vellum/client/types/logical_operator.py +2 -0
  5. vellum/workflows/descriptors/base.py +1 -1
  6. vellum/workflows/descriptors/tests/test_utils.py +3 -0
  7. vellum/workflows/expressions/accessor.py +8 -2
  8. vellum/workflows/nodes/core/map_node/node.py +49 -24
  9. vellum/workflows/nodes/core/map_node/tests/test_node.py +4 -4
  10. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +1 -1
  11. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +5 -3
  12. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +3 -0
  13. vellum/workflows/nodes/displayable/bases/search_node.py +37 -2
  14. vellum/workflows/nodes/displayable/bases/tests/__init__.py +0 -0
  15. vellum/workflows/nodes/displayable/bases/tests/test_utils.py +61 -0
  16. vellum/workflows/nodes/displayable/bases/types.py +42 -0
  17. vellum/workflows/nodes/displayable/bases/utils.py +112 -0
  18. vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +0 -1
  19. vellum/workflows/nodes/displayable/search_node/tests/__init__.py +0 -0
  20. vellum/workflows/nodes/displayable/search_node/tests/test_node.py +164 -0
  21. vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py +2 -3
  22. vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +0 -1
  23. vellum/workflows/runner/runner.py +37 -4
  24. vellum/workflows/types/tests/test_utils.py +5 -2
  25. vellum/workflows/types/utils.py +4 -0
  26. vellum/workflows/workflows/base.py +14 -0
  27. {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/METADATA +1 -1
  28. {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/RECORD +46 -36
  29. vellum_cli/__init__.py +10 -0
  30. vellum_cli/ping.py +28 -0
  31. vellum_cli/tests/test_ping.py +47 -0
  32. vellum_ee/workflows/display/nodes/vellum/base_node.py +22 -9
  33. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +3 -0
  34. vellum_ee/workflows/display/nodes/vellum/map_node.py +1 -1
  35. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +14 -10
  36. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +2 -2
  37. vellum_ee/workflows/display/nodes/vellum/utils.py +8 -1
  38. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +67 -0
  39. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +66 -0
  40. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_ports_serialization.py +660 -0
  41. vellum_ee/workflows/display/utils/vellum.py +4 -42
  42. vellum_ee/workflows/display/vellum.py +7 -36
  43. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +2 -1
  44. {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/LICENSE +0 -0
  45. {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/WHEEL +0 -0
  46. {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
18
18
  headers: typing.Dict[str, str] = {
19
19
  "X-Fern-Language": "Python",
20
20
  "X-Fern-SDK-Name": "vellum-ai",
21
- "X-Fern-SDK-Version": "0.13.0",
21
+ "X-Fern-SDK-Version": "0.13.1",
22
22
  }
23
23
  headers["X_API_KEY"] = self.api_key
24
24
  return headers
@@ -133,6 +133,11 @@ class UniversalBaseModel(pydantic.BaseModel):
133
133
  #
134
134
  # We'd ideally do the same for Pydantic V2, but it shells out to a library to serialize models
135
135
  # that we have less control over, and this is less intrusive than custom serializers for now.
136
+ kwargs = {
137
+ **kwargs,
138
+ "warnings": False,
139
+ }
140
+
136
141
  if IS_PYDANTIC_V2:
137
142
  kwargs_with_defaults_exclude_unset: typing.Any = {
138
143
  **kwargs,
@@ -101,6 +101,7 @@ class WorkflowsClient:
101
101
  deployment_config: typing.Optional[WorkflowPushDeploymentConfigRequest] = OMIT,
102
102
  artifact: typing.Optional[core.File] = OMIT,
103
103
  dry_run: typing.Optional[bool] = OMIT,
104
+ strict: typing.Optional[bool] = OMIT,
104
105
  request_options: typing.Optional[RequestOptions] = None,
105
106
  ) -> WorkflowPushResponse:
106
107
  """
@@ -120,6 +121,8 @@ class WorkflowsClient:
120
121
 
121
122
  dry_run : typing.Optional[bool]
122
123
 
124
+ strict : typing.Optional[bool]
125
+
123
126
  request_options : typing.Optional[RequestOptions]
124
127
  Request-specific configuration.
125
128
 
@@ -150,6 +153,7 @@ class WorkflowsClient:
150
153
  "workflow_sandbox_id": workflow_sandbox_id,
151
154
  "deployment_config": deployment_config,
152
155
  "dry_run": dry_run,
156
+ "strict": strict,
153
157
  },
154
158
  files={
155
159
  "artifact": artifact,
@@ -254,6 +258,7 @@ class AsyncWorkflowsClient:
254
258
  deployment_config: typing.Optional[WorkflowPushDeploymentConfigRequest] = OMIT,
255
259
  artifact: typing.Optional[core.File] = OMIT,
256
260
  dry_run: typing.Optional[bool] = OMIT,
261
+ strict: typing.Optional[bool] = OMIT,
257
262
  request_options: typing.Optional[RequestOptions] = None,
258
263
  ) -> WorkflowPushResponse:
259
264
  """
@@ -273,6 +278,8 @@ class AsyncWorkflowsClient:
273
278
 
274
279
  dry_run : typing.Optional[bool]
275
280
 
281
+ strict : typing.Optional[bool]
282
+
276
283
  request_options : typing.Optional[RequestOptions]
277
284
  Request-specific configuration.
278
285
 
@@ -311,6 +318,7 @@ class AsyncWorkflowsClient:
311
318
  "workflow_sandbox_id": workflow_sandbox_id,
312
319
  "deployment_config": deployment_config,
313
320
  "dry_run": dry_run,
321
+ "strict": strict,
314
322
  },
315
323
  files={
316
324
  "artifact": artifact,
@@ -26,6 +26,8 @@ LogicalOperator = typing.Union[
26
26
  "notBlank",
27
27
  "coalesce",
28
28
  "accessField",
29
+ "and",
30
+ "or",
29
31
  ],
30
32
  typing.Any,
31
33
  ]
@@ -121,7 +121,7 @@ class BaseDescriptor(Generic[_T]):
121
121
 
122
122
  return CoalesceExpression(lhs=self, rhs=other)
123
123
 
124
- def __getitem__(self, field: str) -> "AccessorExpression":
124
+ def __getitem__(self, field: Union[str, int]) -> "AccessorExpression":
125
125
  from vellum.workflows.expressions.accessor import AccessorExpression
126
126
 
127
127
  return AccessorExpression(base=self, field=field)
@@ -19,6 +19,7 @@ class FixtureState(BaseState):
19
19
  }
20
20
 
21
21
  eta = None
22
+ theta = ["baz"]
22
23
 
23
24
 
24
25
  class DummyNode(BaseNode[FixtureState]):
@@ -75,6 +76,7 @@ class DummyNode(BaseNode[FixtureState]):
75
76
  ),
76
77
  (FixtureState.zeta["foo"], "bar"),
77
78
  (ConstantValueReference(1), 1),
79
+ (FixtureState.theta[0], "baz"),
78
80
  ],
79
81
  ids=[
80
82
  "or",
@@ -119,6 +121,7 @@ class DummyNode(BaseNode[FixtureState]):
119
121
  "or_and",
120
122
  "accessor",
121
123
  "constants",
124
+ "list_index",
122
125
  ],
123
126
  )
124
127
  def test_resolve_value__happy_path(descriptor, expected_value):
@@ -1,6 +1,6 @@
1
1
  from collections.abc import Mapping
2
2
  import dataclasses
3
- from typing import Any, Sequence, Type, TypeVar
3
+ from typing import Any, Sequence, Type, TypeVar, Union
4
4
 
5
5
  from pydantic import BaseModel, GetCoreSchemaHandler
6
6
  from pydantic_core import core_schema
@@ -17,7 +17,7 @@ class AccessorExpression(BaseDescriptor[Any]):
17
17
  self,
18
18
  *,
19
19
  base: BaseDescriptor[LHS],
20
- field: str,
20
+ field: Union[str, int],
21
21
  ) -> None:
22
22
  super().__init__(
23
23
  name=f"{base.name}.{field}",
@@ -31,9 +31,15 @@ class AccessorExpression(BaseDescriptor[Any]):
31
31
  base = resolve_value(self._base, state)
32
32
 
33
33
  if dataclasses.is_dataclass(base):
34
+ if isinstance(self._field, int):
35
+ raise ValueError("Cannot access field by index on a dataclass")
36
+
34
37
  return getattr(base, self._field)
35
38
 
36
39
  if isinstance(base, BaseModel):
40
+ if isinstance(self._field, int):
41
+ raise ValueError("Cannot access field by index on a BaseModel")
42
+
37
43
  return getattr(base, self._field)
38
44
 
39
45
  if isinstance(base, Mapping):
@@ -1,7 +1,20 @@
1
1
  from collections import defaultdict
2
2
  from queue import Empty, Queue
3
3
  from threading import Thread
4
- from typing import TYPE_CHECKING, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, overload
4
+ from typing import (
5
+ TYPE_CHECKING,
6
+ Callable,
7
+ Dict,
8
+ Generic,
9
+ Iterator,
10
+ List,
11
+ Optional,
12
+ Tuple,
13
+ Type,
14
+ TypeVar,
15
+ Union,
16
+ overload,
17
+ )
5
18
 
6
19
  from vellum.workflows.context import execution_context, get_parent_context
7
20
  from vellum.workflows.descriptors.base import BaseDescriptor
@@ -12,6 +25,7 @@ from vellum.workflows.inputs.base import BaseInputs
12
25
  from vellum.workflows.nodes.bases.base_adornment_node import BaseAdornmentNode
13
26
  from vellum.workflows.nodes.utils import create_adornment
14
27
  from vellum.workflows.outputs import BaseOutputs
28
+ from vellum.workflows.outputs.base import BaseOutput
15
29
  from vellum.workflows.references.output import OutputReference
16
30
  from vellum.workflows.state.context import WorkflowContext
17
31
  from vellum.workflows.types.generics import StateType
@@ -29,11 +43,11 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
29
43
 
30
44
  items: List[MapNodeItemType] - The items to map over
31
45
  subworkflow: Type["BaseWorkflow[SubworkflowInputs, BaseState]"] - The Subworkflow to execute on each iteration
32
- concurrency: Optional[int] = None - The maximum number of concurrent subworkflow executions
46
+ max_concurrency: Optional[int] = None - The maximum number of concurrent subworkflow executions
33
47
  """
34
48
 
35
49
  items: List[MapNodeItemType]
36
- concurrency: Optional[int] = None
50
+ max_concurrency: Optional[int] = None
37
51
 
38
52
  class Outputs(BaseAdornmentNode.Outputs):
39
53
  pass
@@ -47,7 +61,7 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
47
61
  index: int
48
62
  all_items: List[MapNodeItemType] # type: ignore[valid-type]
49
63
 
50
- def run(self) -> Outputs:
64
+ def run(self) -> Iterator[BaseOutput]:
51
65
  mapped_items: Dict[str, List] = defaultdict(list)
52
66
  for output_descripter in self.subworkflow.Outputs:
53
67
  mapped_items[output_descripter.name] = [None] * len(self.items)
@@ -66,14 +80,14 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
66
80
  "parent_context": parent_context,
67
81
  },
68
82
  )
69
- if self.concurrency is None:
83
+ if self.max_concurrency is None:
70
84
  thread.start()
71
85
  else:
72
86
  self._concurrency_queue.put(thread)
73
87
 
74
- if self.concurrency is not None:
88
+ if self.max_concurrency is not None:
75
89
  concurrency_count = 0
76
- while concurrency_count < self.concurrency:
90
+ while concurrency_count < self.max_concurrency:
77
91
  is_empty = self._start_thread()
78
92
  if is_empty:
79
93
  break
@@ -83,40 +97,45 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
83
97
  try:
84
98
  while map_node_event := self._event_queue.get():
85
99
  index = map_node_event[0]
86
- terminal_event = map_node_event[1]
87
- self._context._emit_subworkflow_event(terminal_event)
100
+ subworkflow_event = map_node_event[1]
101
+ self._context._emit_subworkflow_event(subworkflow_event)
88
102
 
89
- if terminal_event.name == "workflow.execution.fulfilled":
90
- workflow_output_vars = vars(terminal_event.outputs)
103
+ if subworkflow_event.name == "workflow.execution.initiated":
104
+ for output_name in mapped_items.keys():
105
+ yield BaseOutput(name=output_name, delta=(None, index, "INITIATED"))
106
+
107
+ elif subworkflow_event.name == "workflow.execution.fulfilled":
108
+ workflow_output_vars = vars(subworkflow_event.outputs)
91
109
 
92
110
  for output_name in workflow_output_vars:
93
111
  output_mapped_items = mapped_items[output_name]
94
112
  output_mapped_items[index] = workflow_output_vars[output_name]
113
+ yield BaseOutput(
114
+ name=output_name,
115
+ delta=(output_mapped_items[index], index, "FULFILLED"),
116
+ )
95
117
 
96
118
  fulfilled_iterations[index] = True
97
119
  if all(fulfilled_iterations):
98
120
  break
99
121
 
100
- if self.concurrency is not None:
122
+ if self.max_concurrency is not None:
101
123
  self._start_thread()
102
- elif terminal_event.name == "workflow.execution.paused":
124
+ elif subworkflow_event.name == "workflow.execution.paused":
103
125
  raise NodeException(
104
126
  code=WorkflowErrorCode.INVALID_OUTPUTS,
105
127
  message=f"Subworkflow unexpectedly paused on iteration {index}",
106
128
  )
107
- elif terminal_event.name == "workflow.execution.rejected":
129
+ elif subworkflow_event.name == "workflow.execution.rejected":
108
130
  raise NodeException(
109
- f"Subworkflow failed on iteration {index} with error: {terminal_event.error.message}",
110
- code=terminal_event.error.code,
131
+ f"Subworkflow failed on iteration {index} with error: {subworkflow_event.error.message}",
132
+ code=subworkflow_event.error.code,
111
133
  )
112
134
  except Empty:
113
135
  pass
114
136
 
115
- outputs = self.Outputs()
116
137
  for output_name, output_list in mapped_items.items():
117
- setattr(outputs, output_name, output_list)
118
-
119
- return outputs
138
+ yield BaseOutput(name=output_name, value=output_list)
120
139
 
121
140
  def _context_run_subworkflow(
122
141
  self, *, item: MapNodeItemType, index: int, parent_context: Optional[ParentContext] = None
@@ -149,21 +168,27 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
149
168
 
150
169
  @overload
151
170
  @classmethod
152
- def wrap(cls, items: List[MapNodeItemType]) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]: ...
171
+ def wrap(
172
+ cls, items: List[MapNodeItemType], max_concurrency: Optional[int] = None
173
+ ) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]: ...
153
174
 
154
175
  # TODO: We should be able to do this overload automatically as we do with node attributes
155
176
  # https://app.shortcut.com/vellum/story/5289
156
177
  @overload
157
178
  @classmethod
158
179
  def wrap(
159
- cls, items: BaseDescriptor[List[MapNodeItemType]]
180
+ cls,
181
+ items: BaseDescriptor[List[MapNodeItemType]],
182
+ max_concurrency: Optional[int] = None,
160
183
  ) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]: ...
161
184
 
162
185
  @classmethod
163
186
  def wrap(
164
- cls, items: Union[List[MapNodeItemType], BaseDescriptor[List[MapNodeItemType]]]
187
+ cls,
188
+ items: Union[List[MapNodeItemType], BaseDescriptor[List[MapNodeItemType]]],
189
+ max_concurrency: Optional[int] = None,
165
190
  ) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]:
166
- return create_adornment(cls, attributes={"items": items})
191
+ return create_adornment(cls, attributes={"items": items, "max_concurrency": max_concurrency})
167
192
 
168
193
  @classmethod
169
194
  def __annotate_outputs_class__(cls, outputs_class: Type[BaseOutputs], reference: OutputReference) -> None:
@@ -3,7 +3,7 @@ import time
3
3
  from vellum.workflows.inputs.base import BaseInputs
4
4
  from vellum.workflows.nodes.bases import BaseNode
5
5
  from vellum.workflows.nodes.core.map_node.node import MapNode
6
- from vellum.workflows.outputs.base import BaseOutputs
6
+ from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
7
7
  from vellum.workflows.state.base import BaseState, StateMeta
8
8
 
9
9
 
@@ -35,10 +35,10 @@ def test_map_node__use_parent_inputs_and_state():
35
35
  meta=StateMeta(workflow_inputs=Inputs(foo="foo")),
36
36
  )
37
37
  )
38
- outputs = node.run()
38
+ outputs = list(node.run())
39
39
 
40
40
  # THEN the data is used successfully
41
- assert outputs.value == ["foo bar 1", "foo bar 2", "foo bar 3"]
41
+ assert outputs[-1] == BaseOutput(name="value", value=["foo bar 1", "foo bar 2", "foo bar 3"])
42
42
 
43
43
 
44
44
  def test_map_node__use_parallelism():
@@ -62,4 +62,4 @@ def test_map_node__use_parallelism():
62
62
 
63
63
  # THEN the node should have ran in parallel
64
64
  run_time = (end_ts - start_ts) / 10**9
65
- assert run_time < 0.1
65
+ assert run_time < 0.2
@@ -13,7 +13,7 @@ from vellum.workflows.types.generics import StateType
13
13
 
14
14
  class BasePromptNode(BaseNode, Generic[StateType]):
15
15
  # Inputs that are passed to the Prompt
16
- prompt_inputs: ClassVar[EntityInputsInterface]
16
+ prompt_inputs: ClassVar[Optional[EntityInputsInterface]] = None
17
17
 
18
18
  request_options: Optional[RequestOptions] = None
19
19
 
@@ -53,8 +53,7 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
53
53
 
54
54
  def _get_prompt_event_stream(self) -> Iterator[AdHocExecutePromptEvent]:
55
55
  input_variables, input_values = self._compile_prompt_inputs()
56
- current_parent_context = get_parent_context()
57
- parent_context = current_parent_context.model_dump_json() if current_parent_context else None
56
+ parent_context = get_parent_context()
58
57
  request_options = self.request_options or RequestOptions()
59
58
  request_options["additional_body_parameters"] = {
60
59
  "execution_context": {"parent_context": parent_context},
@@ -77,13 +76,16 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
77
76
  blocks=self.blocks,
78
77
  functions=normalized_functions,
79
78
  expand_meta=self.expand_meta,
80
- request_options=self.request_options,
79
+ request_options=request_options,
81
80
  )
82
81
 
83
82
  def _compile_prompt_inputs(self) -> Tuple[List[VellumVariable], List[PromptRequestInput]]:
84
83
  input_variables: List[VellumVariable] = []
85
84
  input_values: List[PromptRequestInput] = []
86
85
 
86
+ if not self.prompt_inputs:
87
+ return input_variables, input_values
88
+
87
89
  for input_name, input_value in self.prompt_inputs.items():
88
90
  if isinstance(input_value, str):
89
91
  input_variables.append(
@@ -74,6 +74,9 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
74
74
 
75
75
  compiled_inputs: List[PromptDeploymentInputRequest] = []
76
76
 
77
+ if not self.prompt_inputs:
78
+ return compiled_inputs
79
+
77
80
  for input_name, input_value in self.prompt_inputs.items():
78
81
  if isinstance(input_value, str):
79
82
  compiled_inputs.append(
@@ -15,6 +15,7 @@ from vellum.core import ApiError, RequestOptions
15
15
  from vellum.workflows.errors import WorkflowErrorCode
16
16
  from vellum.workflows.exceptions import NodeException
17
17
  from vellum.workflows.nodes.bases import BaseNode
18
+ from vellum.workflows.nodes.displayable.bases.types import SearchFilters
18
19
  from vellum.workflows.outputs import BaseOutputs
19
20
  from vellum.workflows.types.generics import StateType
20
21
 
@@ -33,7 +34,11 @@ class BaseSearchNode(BaseNode[StateType], Generic[StateType]):
33
34
  document_index: Union[UUID, str] - Either the UUID or name of the Vellum Document Index that you'd like to search
34
35
  against
35
36
  query: str - The query to search for
36
- options: Optional[SearchRequestOptionsRequest] = None - Runtime configuration for the search
37
+ limit: Optional[int] = None - The maximum number of results to return.
38
+ weights: Optional[SearchWeightsRequest] = None - The weights to use for the search. Must add up to 1.0.
39
+ result_merging: Optional[SearchResultMergingRequest] = None - The configuration for merging results.
40
+ filters: Optional[SearchFiltersRequest] = None - The filters to apply to the search.
41
+ options: Optional[SearchRequestOptionsRequest] = None - [DEPRECATED] Runtime configuration for the search
37
42
  request_options: Optional[RequestOptions] = None - The request options to use for the search
38
43
  """
39
44
 
@@ -43,11 +48,24 @@ class BaseSearchNode(BaseNode[StateType], Generic[StateType]):
43
48
  # The Document Index to Search against. Identified by either its UUID or its name.
44
49
  document_index: ClassVar[Union[UUID, str]]
45
50
 
51
+ # The maximum number of results to return.
52
+ limit: ClassVar[Optional[int]] = None
53
+
54
+ # The weights to use for the search. Must add up to 1.0.
55
+ weights: ClassVar[Optional[SearchWeightsRequest]] = None
56
+
57
+ # The configuration for merging results.
58
+ result_merging: ClassVar[Optional[SearchResultMergingRequest]] = None
59
+
60
+ # The filters to apply to the search.
61
+ filters: ClassVar[Optional[SearchFilters]] = None
62
+
46
63
  # Ideally we could reuse node descriptors to derive other node descriptor values. Two action items are
47
64
  # blocking us from doing so in this use case:
48
65
  # 1. Node Descriptor resolution during runtime - https://app.shortcut.com/vellum/story/4781
49
66
  # 2. Math operations between descriptors - https://app.shortcut.com/vellum/story/4782
50
67
  # search_weights = DEFAULT_SEARCH_WEIGHTS
68
+ # Deprecated: Use the top level `limit`, `weights`, `result_merging`, and `filters` attributes instead
51
69
  options = SearchRequestOptionsRequest(
52
70
  limit=DEFAULT_SEARCH_LIMIT,
53
71
  weights=SearchWeightsRequest(
@@ -77,7 +95,7 @@ class BaseSearchNode(BaseNode[StateType], Generic[StateType]):
77
95
  return self._context.vellum_client.search(
78
96
  query=self.query,
79
97
  document_index=str(self.document_index),
80
- options=self.options,
98
+ options=self._get_options_request(),
81
99
  )
82
100
  except NotFoundError:
83
101
  raise NodeException(
@@ -90,6 +108,23 @@ class BaseSearchNode(BaseNode[StateType], Generic[StateType]):
90
108
  code=WorkflowErrorCode.INTERNAL_ERROR,
91
109
  )
92
110
 
111
+ def _get_options_request(self) -> SearchRequestOptionsRequest:
112
+ return SearchRequestOptionsRequest(
113
+ limit=self.limit if self.limit is not None else self.options.limit,
114
+ weights=self.weights if self.weights is not None else self.options.weights,
115
+ result_merging=self.result_merging if self.result_merging is not None else self.options.result_merging,
116
+ filters=self._get_filters_request(),
117
+ )
118
+
119
+ def _get_filters_request(self) -> Optional[SearchFiltersRequest]:
120
+ if self.filters is None:
121
+ return self.options.filters
122
+
123
+ return SearchFiltersRequest(
124
+ external_ids=self.filters.external_ids,
125
+ metadata=self.filters.metadata.to_request() if self.filters.metadata is not None else None,
126
+ )
127
+
93
128
  def run(self) -> Outputs:
94
129
  response = self._perform_search()
95
130
  return self.Outputs(results=response.results)
@@ -0,0 +1,61 @@
1
+ import pytest
2
+ import enum
3
+
4
+ from vellum.client.types.chat_history_vellum_value import ChatHistoryVellumValue
5
+ from vellum.client.types.chat_message import ChatMessage
6
+ from vellum.client.types.json_vellum_value import JsonVellumValue
7
+ from vellum.client.types.number_vellum_value import NumberVellumValue
8
+ from vellum.client.types.search_result import SearchResult
9
+ from vellum.client.types.search_result_document import SearchResultDocument
10
+ from vellum.client.types.search_results_vellum_value import SearchResultsVellumValue
11
+ from vellum.client.types.string_vellum_value import StringVellumValue
12
+ from vellum.client.types.string_vellum_value_request import StringVellumValueRequest
13
+ from vellum.workflows.nodes.displayable.bases.utils import primitive_to_vellum_value, primitive_to_vellum_value_request
14
+
15
+
16
+ class MockEnum(enum.Enum):
17
+ FOO = "foo"
18
+
19
+
20
+ @pytest.mark.parametrize(
21
+ ["value", "expected_output"],
22
+ [
23
+ ("hello", StringVellumValue(value="hello")),
24
+ (MockEnum.FOO, StringVellumValue(value="foo")),
25
+ (1, NumberVellumValue(value=1)),
26
+ (1.0, NumberVellumValue(value=1.0)),
27
+ (
28
+ [ChatMessage(role="USER", text="hello")],
29
+ ChatHistoryVellumValue(value=[ChatMessage(role="USER", text="hello")]),
30
+ ),
31
+ (
32
+ [
33
+ SearchResult(
34
+ text="Search query",
35
+ score="0.0",
36
+ keywords=["keywords"],
37
+ document=SearchResultDocument(label="label"),
38
+ )
39
+ ],
40
+ SearchResultsVellumValue(
41
+ value=[
42
+ SearchResult(
43
+ text="Search query",
44
+ score="0.0",
45
+ keywords=["keywords"],
46
+ document=SearchResultDocument(label="label"),
47
+ )
48
+ ]
49
+ ),
50
+ ),
51
+ (StringVellumValue(value="hello"), StringVellumValue(value="hello")),
52
+ (StringVellumValueRequest(value="hello"), StringVellumValueRequest(value="hello")),
53
+ ({"foo": "bar"}, JsonVellumValue(value={"foo": "bar"})),
54
+ ],
55
+ )
56
+ def test_primitive_to_vellum_value(value, expected_output):
57
+ assert primitive_to_vellum_value(value) == expected_output
58
+
59
+
60
+ def test_primitive_to_vellum_value_request():
61
+ assert primitive_to_vellum_value_request("hello") == StringVellumValueRequest(value="hello")
@@ -0,0 +1,42 @@
1
+ from typing import Any, List, Optional, Union
2
+
3
+ from vellum.client.core.pydantic_utilities import UniversalBaseModel
4
+ from vellum.client.types.condition_combinator import ConditionCombinator
5
+ from vellum.client.types.logical_operator import LogicalOperator
6
+ from vellum.client.types.vellum_value_logical_condition_group_request import VellumValueLogicalConditionGroupRequest
7
+ from vellum.client.types.vellum_value_logical_condition_request import VellumValueLogicalConditionRequest
8
+ from vellum.workflows.nodes.displayable.bases.utils import primitive_to_vellum_value_request
9
+
10
+
11
+ class MetadataLogicalConditionGroup(UniversalBaseModel):
12
+ combinator: ConditionCombinator
13
+ negated: bool
14
+ conditions: List["MetadataLogicalExpression"]
15
+
16
+ def to_request(self) -> VellumValueLogicalConditionGroupRequest:
17
+ return VellumValueLogicalConditionGroupRequest(
18
+ combinator=self.combinator,
19
+ negated=self.negated,
20
+ conditions=[c.to_request() for c in self.conditions],
21
+ )
22
+
23
+
24
+ class MetadataLogicalCondition(UniversalBaseModel):
25
+ lhs_variable: Any
26
+ operator: LogicalOperator
27
+ rhs_variable: Any
28
+
29
+ def to_request(self) -> VellumValueLogicalConditionRequest:
30
+ return VellumValueLogicalConditionRequest(
31
+ lhs_variable=primitive_to_vellum_value_request(self.lhs_variable),
32
+ operator=self.operator,
33
+ rhs_variable=primitive_to_vellum_value_request(self.rhs_variable),
34
+ )
35
+
36
+
37
+ MetadataLogicalExpression = Union[MetadataLogicalConditionGroup, MetadataLogicalCondition]
38
+
39
+
40
+ class SearchFilters(UniversalBaseModel):
41
+ external_ids: Optional[List[str]] = None
42
+ metadata: Optional[MetadataLogicalConditionGroup] = None