vellum-ai 0.13.0__py3-none-any.whl → 0.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/client/core/pydantic_utilities.py +5 -0
  3. vellum/client/resources/workflows/client.py +8 -0
  4. vellum/client/types/logical_operator.py +2 -0
  5. vellum/workflows/descriptors/base.py +1 -1
  6. vellum/workflows/descriptors/tests/test_utils.py +3 -0
  7. vellum/workflows/expressions/accessor.py +8 -2
  8. vellum/workflows/nodes/core/map_node/node.py +49 -24
  9. vellum/workflows/nodes/core/map_node/tests/test_node.py +4 -4
  10. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +1 -1
  11. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +5 -3
  12. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +3 -0
  13. vellum/workflows/nodes/displayable/bases/search_node.py +37 -2
  14. vellum/workflows/nodes/displayable/bases/tests/__init__.py +0 -0
  15. vellum/workflows/nodes/displayable/bases/tests/test_utils.py +61 -0
  16. vellum/workflows/nodes/displayable/bases/types.py +42 -0
  17. vellum/workflows/nodes/displayable/bases/utils.py +112 -0
  18. vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +0 -1
  19. vellum/workflows/nodes/displayable/search_node/tests/__init__.py +0 -0
  20. vellum/workflows/nodes/displayable/search_node/tests/test_node.py +164 -0
  21. vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py +2 -3
  22. vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +0 -1
  23. vellum/workflows/runner/runner.py +37 -4
  24. vellum/workflows/types/tests/test_utils.py +5 -2
  25. vellum/workflows/types/utils.py +4 -0
  26. vellum/workflows/workflows/base.py +14 -0
  27. {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/METADATA +1 -1
  28. {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/RECORD +46 -36
  29. vellum_cli/__init__.py +10 -0
  30. vellum_cli/ping.py +28 -0
  31. vellum_cli/tests/test_ping.py +47 -0
  32. vellum_ee/workflows/display/nodes/vellum/base_node.py +22 -9
  33. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +3 -0
  34. vellum_ee/workflows/display/nodes/vellum/map_node.py +1 -1
  35. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +14 -10
  36. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +2 -2
  37. vellum_ee/workflows/display/nodes/vellum/utils.py +8 -1
  38. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +67 -0
  39. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +66 -0
  40. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_ports_serialization.py +660 -0
  41. vellum_ee/workflows/display/utils/vellum.py +4 -42
  42. vellum_ee/workflows/display/vellum.py +7 -36
  43. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +2 -1
  44. {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/LICENSE +0 -0
  45. {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/WHEEL +0 -0
  46. {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
18
18
  headers: typing.Dict[str, str] = {
19
19
  "X-Fern-Language": "Python",
20
20
  "X-Fern-SDK-Name": "vellum-ai",
21
- "X-Fern-SDK-Version": "0.13.0",
21
+ "X-Fern-SDK-Version": "0.13.1",
22
22
  }
23
23
  headers["X_API_KEY"] = self.api_key
24
24
  return headers
@@ -133,6 +133,11 @@ class UniversalBaseModel(pydantic.BaseModel):
133
133
  #
134
134
  # We'd ideally do the same for Pydantic V2, but it shells out to a library to serialize models
135
135
  # that we have less control over, and this is less intrusive than custom serializers for now.
136
+ kwargs = {
137
+ **kwargs,
138
+ "warnings": False,
139
+ }
140
+
136
141
  if IS_PYDANTIC_V2:
137
142
  kwargs_with_defaults_exclude_unset: typing.Any = {
138
143
  **kwargs,
@@ -101,6 +101,7 @@ class WorkflowsClient:
101
101
  deployment_config: typing.Optional[WorkflowPushDeploymentConfigRequest] = OMIT,
102
102
  artifact: typing.Optional[core.File] = OMIT,
103
103
  dry_run: typing.Optional[bool] = OMIT,
104
+ strict: typing.Optional[bool] = OMIT,
104
105
  request_options: typing.Optional[RequestOptions] = None,
105
106
  ) -> WorkflowPushResponse:
106
107
  """
@@ -120,6 +121,8 @@ class WorkflowsClient:
120
121
 
121
122
  dry_run : typing.Optional[bool]
122
123
 
124
+ strict : typing.Optional[bool]
125
+
123
126
  request_options : typing.Optional[RequestOptions]
124
127
  Request-specific configuration.
125
128
 
@@ -150,6 +153,7 @@ class WorkflowsClient:
150
153
  "workflow_sandbox_id": workflow_sandbox_id,
151
154
  "deployment_config": deployment_config,
152
155
  "dry_run": dry_run,
156
+ "strict": strict,
153
157
  },
154
158
  files={
155
159
  "artifact": artifact,
@@ -254,6 +258,7 @@ class AsyncWorkflowsClient:
254
258
  deployment_config: typing.Optional[WorkflowPushDeploymentConfigRequest] = OMIT,
255
259
  artifact: typing.Optional[core.File] = OMIT,
256
260
  dry_run: typing.Optional[bool] = OMIT,
261
+ strict: typing.Optional[bool] = OMIT,
257
262
  request_options: typing.Optional[RequestOptions] = None,
258
263
  ) -> WorkflowPushResponse:
259
264
  """
@@ -273,6 +278,8 @@ class AsyncWorkflowsClient:
273
278
 
274
279
  dry_run : typing.Optional[bool]
275
280
 
281
+ strict : typing.Optional[bool]
282
+
276
283
  request_options : typing.Optional[RequestOptions]
277
284
  Request-specific configuration.
278
285
 
@@ -311,6 +318,7 @@ class AsyncWorkflowsClient:
311
318
  "workflow_sandbox_id": workflow_sandbox_id,
312
319
  "deployment_config": deployment_config,
313
320
  "dry_run": dry_run,
321
+ "strict": strict,
314
322
  },
315
323
  files={
316
324
  "artifact": artifact,
@@ -26,6 +26,8 @@ LogicalOperator = typing.Union[
26
26
  "notBlank",
27
27
  "coalesce",
28
28
  "accessField",
29
+ "and",
30
+ "or",
29
31
  ],
30
32
  typing.Any,
31
33
  ]
@@ -121,7 +121,7 @@ class BaseDescriptor(Generic[_T]):
121
121
 
122
122
  return CoalesceExpression(lhs=self, rhs=other)
123
123
 
124
- def __getitem__(self, field: str) -> "AccessorExpression":
124
+ def __getitem__(self, field: Union[str, int]) -> "AccessorExpression":
125
125
  from vellum.workflows.expressions.accessor import AccessorExpression
126
126
 
127
127
  return AccessorExpression(base=self, field=field)
@@ -19,6 +19,7 @@ class FixtureState(BaseState):
19
19
  }
20
20
 
21
21
  eta = None
22
+ theta = ["baz"]
22
23
 
23
24
 
24
25
  class DummyNode(BaseNode[FixtureState]):
@@ -75,6 +76,7 @@ class DummyNode(BaseNode[FixtureState]):
75
76
  ),
76
77
  (FixtureState.zeta["foo"], "bar"),
77
78
  (ConstantValueReference(1), 1),
79
+ (FixtureState.theta[0], "baz"),
78
80
  ],
79
81
  ids=[
80
82
  "or",
@@ -119,6 +121,7 @@ class DummyNode(BaseNode[FixtureState]):
119
121
  "or_and",
120
122
  "accessor",
121
123
  "constants",
124
+ "list_index",
122
125
  ],
123
126
  )
124
127
  def test_resolve_value__happy_path(descriptor, expected_value):
@@ -1,6 +1,6 @@
1
1
  from collections.abc import Mapping
2
2
  import dataclasses
3
- from typing import Any, Sequence, Type, TypeVar
3
+ from typing import Any, Sequence, Type, TypeVar, Union
4
4
 
5
5
  from pydantic import BaseModel, GetCoreSchemaHandler
6
6
  from pydantic_core import core_schema
@@ -17,7 +17,7 @@ class AccessorExpression(BaseDescriptor[Any]):
17
17
  self,
18
18
  *,
19
19
  base: BaseDescriptor[LHS],
20
- field: str,
20
+ field: Union[str, int],
21
21
  ) -> None:
22
22
  super().__init__(
23
23
  name=f"{base.name}.{field}",
@@ -31,9 +31,15 @@ class AccessorExpression(BaseDescriptor[Any]):
31
31
  base = resolve_value(self._base, state)
32
32
 
33
33
  if dataclasses.is_dataclass(base):
34
+ if isinstance(self._field, int):
35
+ raise ValueError("Cannot access field by index on a dataclass")
36
+
34
37
  return getattr(base, self._field)
35
38
 
36
39
  if isinstance(base, BaseModel):
40
+ if isinstance(self._field, int):
41
+ raise ValueError("Cannot access field by index on a BaseModel")
42
+
37
43
  return getattr(base, self._field)
38
44
 
39
45
  if isinstance(base, Mapping):
@@ -1,7 +1,20 @@
1
1
  from collections import defaultdict
2
2
  from queue import Empty, Queue
3
3
  from threading import Thread
4
- from typing import TYPE_CHECKING, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, overload
4
+ from typing import (
5
+ TYPE_CHECKING,
6
+ Callable,
7
+ Dict,
8
+ Generic,
9
+ Iterator,
10
+ List,
11
+ Optional,
12
+ Tuple,
13
+ Type,
14
+ TypeVar,
15
+ Union,
16
+ overload,
17
+ )
5
18
 
6
19
  from vellum.workflows.context import execution_context, get_parent_context
7
20
  from vellum.workflows.descriptors.base import BaseDescriptor
@@ -12,6 +25,7 @@ from vellum.workflows.inputs.base import BaseInputs
12
25
  from vellum.workflows.nodes.bases.base_adornment_node import BaseAdornmentNode
13
26
  from vellum.workflows.nodes.utils import create_adornment
14
27
  from vellum.workflows.outputs import BaseOutputs
28
+ from vellum.workflows.outputs.base import BaseOutput
15
29
  from vellum.workflows.references.output import OutputReference
16
30
  from vellum.workflows.state.context import WorkflowContext
17
31
  from vellum.workflows.types.generics import StateType
@@ -29,11 +43,11 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
29
43
 
30
44
  items: List[MapNodeItemType] - The items to map over
31
45
  subworkflow: Type["BaseWorkflow[SubworkflowInputs, BaseState]"] - The Subworkflow to execute on each iteration
32
- concurrency: Optional[int] = None - The maximum number of concurrent subworkflow executions
46
+ max_concurrency: Optional[int] = None - The maximum number of concurrent subworkflow executions
33
47
  """
34
48
 
35
49
  items: List[MapNodeItemType]
36
- concurrency: Optional[int] = None
50
+ max_concurrency: Optional[int] = None
37
51
 
38
52
  class Outputs(BaseAdornmentNode.Outputs):
39
53
  pass
@@ -47,7 +61,7 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
47
61
  index: int
48
62
  all_items: List[MapNodeItemType] # type: ignore[valid-type]
49
63
 
50
- def run(self) -> Outputs:
64
+ def run(self) -> Iterator[BaseOutput]:
51
65
  mapped_items: Dict[str, List] = defaultdict(list)
52
66
  for output_descripter in self.subworkflow.Outputs:
53
67
  mapped_items[output_descripter.name] = [None] * len(self.items)
@@ -66,14 +80,14 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
66
80
  "parent_context": parent_context,
67
81
  },
68
82
  )
69
- if self.concurrency is None:
83
+ if self.max_concurrency is None:
70
84
  thread.start()
71
85
  else:
72
86
  self._concurrency_queue.put(thread)
73
87
 
74
- if self.concurrency is not None:
88
+ if self.max_concurrency is not None:
75
89
  concurrency_count = 0
76
- while concurrency_count < self.concurrency:
90
+ while concurrency_count < self.max_concurrency:
77
91
  is_empty = self._start_thread()
78
92
  if is_empty:
79
93
  break
@@ -83,40 +97,45 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
83
97
  try:
84
98
  while map_node_event := self._event_queue.get():
85
99
  index = map_node_event[0]
86
- terminal_event = map_node_event[1]
87
- self._context._emit_subworkflow_event(terminal_event)
100
+ subworkflow_event = map_node_event[1]
101
+ self._context._emit_subworkflow_event(subworkflow_event)
88
102
 
89
- if terminal_event.name == "workflow.execution.fulfilled":
90
- workflow_output_vars = vars(terminal_event.outputs)
103
+ if subworkflow_event.name == "workflow.execution.initiated":
104
+ for output_name in mapped_items.keys():
105
+ yield BaseOutput(name=output_name, delta=(None, index, "INITIATED"))
106
+
107
+ elif subworkflow_event.name == "workflow.execution.fulfilled":
108
+ workflow_output_vars = vars(subworkflow_event.outputs)
91
109
 
92
110
  for output_name in workflow_output_vars:
93
111
  output_mapped_items = mapped_items[output_name]
94
112
  output_mapped_items[index] = workflow_output_vars[output_name]
113
+ yield BaseOutput(
114
+ name=output_name,
115
+ delta=(output_mapped_items[index], index, "FULFILLED"),
116
+ )
95
117
 
96
118
  fulfilled_iterations[index] = True
97
119
  if all(fulfilled_iterations):
98
120
  break
99
121
 
100
- if self.concurrency is not None:
122
+ if self.max_concurrency is not None:
101
123
  self._start_thread()
102
- elif terminal_event.name == "workflow.execution.paused":
124
+ elif subworkflow_event.name == "workflow.execution.paused":
103
125
  raise NodeException(
104
126
  code=WorkflowErrorCode.INVALID_OUTPUTS,
105
127
  message=f"Subworkflow unexpectedly paused on iteration {index}",
106
128
  )
107
- elif terminal_event.name == "workflow.execution.rejected":
129
+ elif subworkflow_event.name == "workflow.execution.rejected":
108
130
  raise NodeException(
109
- f"Subworkflow failed on iteration {index} with error: {terminal_event.error.message}",
110
- code=terminal_event.error.code,
131
+ f"Subworkflow failed on iteration {index} with error: {subworkflow_event.error.message}",
132
+ code=subworkflow_event.error.code,
111
133
  )
112
134
  except Empty:
113
135
  pass
114
136
 
115
- outputs = self.Outputs()
116
137
  for output_name, output_list in mapped_items.items():
117
- setattr(outputs, output_name, output_list)
118
-
119
- return outputs
138
+ yield BaseOutput(name=output_name, value=output_list)
120
139
 
121
140
  def _context_run_subworkflow(
122
141
  self, *, item: MapNodeItemType, index: int, parent_context: Optional[ParentContext] = None
@@ -149,21 +168,27 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
149
168
 
150
169
  @overload
151
170
  @classmethod
152
- def wrap(cls, items: List[MapNodeItemType]) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]: ...
171
+ def wrap(
172
+ cls, items: List[MapNodeItemType], max_concurrency: Optional[int] = None
173
+ ) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]: ...
153
174
 
154
175
  # TODO: We should be able to do this overload automatically as we do with node attributes
155
176
  # https://app.shortcut.com/vellum/story/5289
156
177
  @overload
157
178
  @classmethod
158
179
  def wrap(
159
- cls, items: BaseDescriptor[List[MapNodeItemType]]
180
+ cls,
181
+ items: BaseDescriptor[List[MapNodeItemType]],
182
+ max_concurrency: Optional[int] = None,
160
183
  ) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]: ...
161
184
 
162
185
  @classmethod
163
186
  def wrap(
164
- cls, items: Union[List[MapNodeItemType], BaseDescriptor[List[MapNodeItemType]]]
187
+ cls,
188
+ items: Union[List[MapNodeItemType], BaseDescriptor[List[MapNodeItemType]]],
189
+ max_concurrency: Optional[int] = None,
165
190
  ) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]:
166
- return create_adornment(cls, attributes={"items": items})
191
+ return create_adornment(cls, attributes={"items": items, "max_concurrency": max_concurrency})
167
192
 
168
193
  @classmethod
169
194
  def __annotate_outputs_class__(cls, outputs_class: Type[BaseOutputs], reference: OutputReference) -> None:
@@ -3,7 +3,7 @@ import time
3
3
  from vellum.workflows.inputs.base import BaseInputs
4
4
  from vellum.workflows.nodes.bases import BaseNode
5
5
  from vellum.workflows.nodes.core.map_node.node import MapNode
6
- from vellum.workflows.outputs.base import BaseOutputs
6
+ from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
7
7
  from vellum.workflows.state.base import BaseState, StateMeta
8
8
 
9
9
 
@@ -35,10 +35,10 @@ def test_map_node__use_parent_inputs_and_state():
35
35
  meta=StateMeta(workflow_inputs=Inputs(foo="foo")),
36
36
  )
37
37
  )
38
- outputs = node.run()
38
+ outputs = list(node.run())
39
39
 
40
40
  # THEN the data is used successfully
41
- assert outputs.value == ["foo bar 1", "foo bar 2", "foo bar 3"]
41
+ assert outputs[-1] == BaseOutput(name="value", value=["foo bar 1", "foo bar 2", "foo bar 3"])
42
42
 
43
43
 
44
44
  def test_map_node__use_parallelism():
@@ -62,4 +62,4 @@ def test_map_node__use_parallelism():
62
62
 
63
63
  # THEN the node should have ran in parallel
64
64
  run_time = (end_ts - start_ts) / 10**9
65
- assert run_time < 0.1
65
+ assert run_time < 0.2
@@ -13,7 +13,7 @@ from vellum.workflows.types.generics import StateType
13
13
 
14
14
  class BasePromptNode(BaseNode, Generic[StateType]):
15
15
  # Inputs that are passed to the Prompt
16
- prompt_inputs: ClassVar[EntityInputsInterface]
16
+ prompt_inputs: ClassVar[Optional[EntityInputsInterface]] = None
17
17
 
18
18
  request_options: Optional[RequestOptions] = None
19
19
 
@@ -53,8 +53,7 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
53
53
 
54
54
  def _get_prompt_event_stream(self) -> Iterator[AdHocExecutePromptEvent]:
55
55
  input_variables, input_values = self._compile_prompt_inputs()
56
- current_parent_context = get_parent_context()
57
- parent_context = current_parent_context.model_dump_json() if current_parent_context else None
56
+ parent_context = get_parent_context()
58
57
  request_options = self.request_options or RequestOptions()
59
58
  request_options["additional_body_parameters"] = {
60
59
  "execution_context": {"parent_context": parent_context},
@@ -77,13 +76,16 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
77
76
  blocks=self.blocks,
78
77
  functions=normalized_functions,
79
78
  expand_meta=self.expand_meta,
80
- request_options=self.request_options,
79
+ request_options=request_options,
81
80
  )
82
81
 
83
82
  def _compile_prompt_inputs(self) -> Tuple[List[VellumVariable], List[PromptRequestInput]]:
84
83
  input_variables: List[VellumVariable] = []
85
84
  input_values: List[PromptRequestInput] = []
86
85
 
86
+ if not self.prompt_inputs:
87
+ return input_variables, input_values
88
+
87
89
  for input_name, input_value in self.prompt_inputs.items():
88
90
  if isinstance(input_value, str):
89
91
  input_variables.append(
@@ -74,6 +74,9 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
74
74
 
75
75
  compiled_inputs: List[PromptDeploymentInputRequest] = []
76
76
 
77
+ if not self.prompt_inputs:
78
+ return compiled_inputs
79
+
77
80
  for input_name, input_value in self.prompt_inputs.items():
78
81
  if isinstance(input_value, str):
79
82
  compiled_inputs.append(
@@ -15,6 +15,7 @@ from vellum.core import ApiError, RequestOptions
15
15
  from vellum.workflows.errors import WorkflowErrorCode
16
16
  from vellum.workflows.exceptions import NodeException
17
17
  from vellum.workflows.nodes.bases import BaseNode
18
+ from vellum.workflows.nodes.displayable.bases.types import SearchFilters
18
19
  from vellum.workflows.outputs import BaseOutputs
19
20
  from vellum.workflows.types.generics import StateType
20
21
 
@@ -33,7 +34,11 @@ class BaseSearchNode(BaseNode[StateType], Generic[StateType]):
33
34
  document_index: Union[UUID, str] - Either the UUID or name of the Vellum Document Index that you'd like to search
34
35
  against
35
36
  query: str - The query to search for
36
- options: Optional[SearchRequestOptionsRequest] = None - Runtime configuration for the search
37
+ limit: Optional[int] = None - The maximum number of results to return.
38
+ weights: Optional[SearchWeightsRequest] = None - The weights to use for the search. Must add up to 1.0.
39
+ result_merging: Optional[SearchResultMergingRequest] = None - The configuration for merging results.
40
+ filters: Optional[SearchFiltersRequest] = None - The filters to apply to the search.
41
+ options: Optional[SearchRequestOptionsRequest] = None - [DEPRECATED] Runtime configuration for the search
37
42
  request_options: Optional[RequestOptions] = None - The request options to use for the search
38
43
  """
39
44
 
@@ -43,11 +48,24 @@ class BaseSearchNode(BaseNode[StateType], Generic[StateType]):
43
48
  # The Document Index to Search against. Identified by either its UUID or its name.
44
49
  document_index: ClassVar[Union[UUID, str]]
45
50
 
51
+ # The maximum number of results to return.
52
+ limit: ClassVar[Optional[int]] = None
53
+
54
+ # The weights to use for the search. Must add up to 1.0.
55
+ weights: ClassVar[Optional[SearchWeightsRequest]] = None
56
+
57
+ # The configuration for merging results.
58
+ result_merging: ClassVar[Optional[SearchResultMergingRequest]] = None
59
+
60
+ # The filters to apply to the search.
61
+ filters: ClassVar[Optional[SearchFilters]] = None
62
+
46
63
  # Ideally we could reuse node descriptors to derive other node descriptor values. Two action items are
47
64
  # blocking us from doing so in this use case:
48
65
  # 1. Node Descriptor resolution during runtime - https://app.shortcut.com/vellum/story/4781
49
66
  # 2. Math operations between descriptors - https://app.shortcut.com/vellum/story/4782
50
67
  # search_weights = DEFAULT_SEARCH_WEIGHTS
68
+ # Deprecated: Use the top level `limit`, `weights`, `result_merging`, and `filters` attributes instead
51
69
  options = SearchRequestOptionsRequest(
52
70
  limit=DEFAULT_SEARCH_LIMIT,
53
71
  weights=SearchWeightsRequest(
@@ -77,7 +95,7 @@ class BaseSearchNode(BaseNode[StateType], Generic[StateType]):
77
95
  return self._context.vellum_client.search(
78
96
  query=self.query,
79
97
  document_index=str(self.document_index),
80
- options=self.options,
98
+ options=self._get_options_request(),
81
99
  )
82
100
  except NotFoundError:
83
101
  raise NodeException(
@@ -90,6 +108,23 @@ class BaseSearchNode(BaseNode[StateType], Generic[StateType]):
90
108
  code=WorkflowErrorCode.INTERNAL_ERROR,
91
109
  )
92
110
 
111
+ def _get_options_request(self) -> SearchRequestOptionsRequest:
112
+ return SearchRequestOptionsRequest(
113
+ limit=self.limit if self.limit is not None else self.options.limit,
114
+ weights=self.weights if self.weights is not None else self.options.weights,
115
+ result_merging=self.result_merging if self.result_merging is not None else self.options.result_merging,
116
+ filters=self._get_filters_request(),
117
+ )
118
+
119
+ def _get_filters_request(self) -> Optional[SearchFiltersRequest]:
120
+ if self.filters is None:
121
+ return self.options.filters
122
+
123
+ return SearchFiltersRequest(
124
+ external_ids=self.filters.external_ids,
125
+ metadata=self.filters.metadata.to_request() if self.filters.metadata is not None else None,
126
+ )
127
+
93
128
  def run(self) -> Outputs:
94
129
  response = self._perform_search()
95
130
  return self.Outputs(results=response.results)
@@ -0,0 +1,61 @@
1
+ import pytest
2
+ import enum
3
+
4
+ from vellum.client.types.chat_history_vellum_value import ChatHistoryVellumValue
5
+ from vellum.client.types.chat_message import ChatMessage
6
+ from vellum.client.types.json_vellum_value import JsonVellumValue
7
+ from vellum.client.types.number_vellum_value import NumberVellumValue
8
+ from vellum.client.types.search_result import SearchResult
9
+ from vellum.client.types.search_result_document import SearchResultDocument
10
+ from vellum.client.types.search_results_vellum_value import SearchResultsVellumValue
11
+ from vellum.client.types.string_vellum_value import StringVellumValue
12
+ from vellum.client.types.string_vellum_value_request import StringVellumValueRequest
13
+ from vellum.workflows.nodes.displayable.bases.utils import primitive_to_vellum_value, primitive_to_vellum_value_request
14
+
15
+
16
+ class MockEnum(enum.Enum):
17
+ FOO = "foo"
18
+
19
+
20
+ @pytest.mark.parametrize(
21
+ ["value", "expected_output"],
22
+ [
23
+ ("hello", StringVellumValue(value="hello")),
24
+ (MockEnum.FOO, StringVellumValue(value="foo")),
25
+ (1, NumberVellumValue(value=1)),
26
+ (1.0, NumberVellumValue(value=1.0)),
27
+ (
28
+ [ChatMessage(role="USER", text="hello")],
29
+ ChatHistoryVellumValue(value=[ChatMessage(role="USER", text="hello")]),
30
+ ),
31
+ (
32
+ [
33
+ SearchResult(
34
+ text="Search query",
35
+ score="0.0",
36
+ keywords=["keywords"],
37
+ document=SearchResultDocument(label="label"),
38
+ )
39
+ ],
40
+ SearchResultsVellumValue(
41
+ value=[
42
+ SearchResult(
43
+ text="Search query",
44
+ score="0.0",
45
+ keywords=["keywords"],
46
+ document=SearchResultDocument(label="label"),
47
+ )
48
+ ]
49
+ ),
50
+ ),
51
+ (StringVellumValue(value="hello"), StringVellumValue(value="hello")),
52
+ (StringVellumValueRequest(value="hello"), StringVellumValueRequest(value="hello")),
53
+ ({"foo": "bar"}, JsonVellumValue(value={"foo": "bar"})),
54
+ ],
55
+ )
56
+ def test_primitive_to_vellum_value(value, expected_output):
57
+ assert primitive_to_vellum_value(value) == expected_output
58
+
59
+
60
+ def test_primitive_to_vellum_value_request():
61
+ assert primitive_to_vellum_value_request("hello") == StringVellumValueRequest(value="hello")
@@ -0,0 +1,42 @@
1
+ from typing import Any, List, Optional, Union
2
+
3
+ from vellum.client.core.pydantic_utilities import UniversalBaseModel
4
+ from vellum.client.types.condition_combinator import ConditionCombinator
5
+ from vellum.client.types.logical_operator import LogicalOperator
6
+ from vellum.client.types.vellum_value_logical_condition_group_request import VellumValueLogicalConditionGroupRequest
7
+ from vellum.client.types.vellum_value_logical_condition_request import VellumValueLogicalConditionRequest
8
+ from vellum.workflows.nodes.displayable.bases.utils import primitive_to_vellum_value_request
9
+
10
+
11
+ class MetadataLogicalConditionGroup(UniversalBaseModel):
12
+ combinator: ConditionCombinator
13
+ negated: bool
14
+ conditions: List["MetadataLogicalExpression"]
15
+
16
+ def to_request(self) -> VellumValueLogicalConditionGroupRequest:
17
+ return VellumValueLogicalConditionGroupRequest(
18
+ combinator=self.combinator,
19
+ negated=self.negated,
20
+ conditions=[c.to_request() for c in self.conditions],
21
+ )
22
+
23
+
24
+ class MetadataLogicalCondition(UniversalBaseModel):
25
+ lhs_variable: Any
26
+ operator: LogicalOperator
27
+ rhs_variable: Any
28
+
29
+ def to_request(self) -> VellumValueLogicalConditionRequest:
30
+ return VellumValueLogicalConditionRequest(
31
+ lhs_variable=primitive_to_vellum_value_request(self.lhs_variable),
32
+ operator=self.operator,
33
+ rhs_variable=primitive_to_vellum_value_request(self.rhs_variable),
34
+ )
35
+
36
+
37
+ MetadataLogicalExpression = Union[MetadataLogicalConditionGroup, MetadataLogicalCondition]
38
+
39
+
40
+ class SearchFilters(UniversalBaseModel):
41
+ external_ids: Optional[List[str]] = None
42
+ metadata: Optional[MetadataLogicalConditionGroup] = None