vellum-ai 0.14.44__py3-none-any.whl → 0.14.46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/core/pydantic_utilities.py +7 -1
- vellum/workflows/nodes/bases/base.py +1 -0
- vellum/workflows/nodes/bases/tests/test_base_node.py +20 -0
- vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +8 -14
- vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +62 -0
- vellum/workflows/nodes/displayable/code_execution_node/utils.py +3 -54
- vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +5 -6
- vellum/workflows/nodes/utils.py +4 -0
- vellum/workflows/ports/port.py +13 -3
- vellum/workflows/types/code_execution_node_wrappers.py +64 -0
- vellum/workflows/types/tests/test_utils.py +3 -3
- vellum/workflows/types/utils.py +31 -10
- vellum/workflows/vellum_client.py +19 -7
- {vellum_ai-0.14.44.dist-info → vellum_ai-0.14.46.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.44.dist-info → vellum_ai-0.14.46.dist-info}/RECORD +56 -53
- vellum_cli/config.py +7 -2
- vellum_cli/push.py +5 -1
- vellum_cli/tests/test_push.py +192 -8
- vellum_ee/workflows/display/nodes/base_node_display.py +4 -173
- vellum_ee/workflows/display/nodes/vellum/conditional_node.py +1 -1
- vellum_ee/workflows/display/nodes/vellum/final_output_node.py +2 -1
- vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +5 -6
- vellum_ee/workflows/display/nodes/vellum/retry_node.py +3 -3
- vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +5 -6
- vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_deployment_node.py +106 -0
- vellum_ee/workflows/display/nodes/vellum/tests/test_subworkflow_deployment_node.py +109 -0
- vellum_ee/workflows/display/nodes/vellum/try_node.py +3 -3
- vellum_ee/workflows/display/tests/test_base_workflow_display.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +73 -111
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +0 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +0 -3
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +0 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +0 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +0 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +0 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +0 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +18 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +10 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +0 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +2 -3
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +0 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +2 -3
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +0 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +1 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +0 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +5 -55
- vellum_ee/workflows/display/types.py +3 -0
- vellum_ee/workflows/display/utils/expressions.py +222 -2
- vellum_ee/workflows/display/utils/vellum.py +1 -79
- vellum_ee/workflows/display/workflows/base_workflow_display.py +59 -37
- vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +3 -0
- vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +98 -0
- {vellum_ai-0.14.44.dist-info → vellum_ai-0.14.46.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.44.dist-info → vellum_ai-0.14.46.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.44.dist-info → vellum_ai-0.14.46.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
|
|
18
18
|
headers: typing.Dict[str, str] = {
|
19
19
|
"X-Fern-Language": "Python",
|
20
20
|
"X-Fern-SDK-Name": "vellum-ai",
|
21
|
-
"X-Fern-SDK-Version": "0.14.
|
21
|
+
"X-Fern-SDK-Version": "0.14.46",
|
22
22
|
}
|
23
23
|
headers["X-API-KEY"] = self.api_key
|
24
24
|
return headers
|
@@ -8,10 +8,13 @@ from collections import defaultdict
|
|
8
8
|
import typing_extensions
|
9
9
|
|
10
10
|
import pydantic
|
11
|
+
import logging
|
11
12
|
|
12
13
|
from .datetime_utils import serialize_datetime
|
13
14
|
from .serialization import convert_and_respect_annotation_metadata
|
14
15
|
|
16
|
+
logger = logging.getLogger(__name__)
|
17
|
+
|
15
18
|
IS_PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
|
16
19
|
|
17
20
|
if IS_PYDANTIC_V2:
|
@@ -245,7 +248,10 @@ def update_forward_refs(model: typing.Type["Model"], **localns: typing.Any) -> N
|
|
245
248
|
if IS_PYDANTIC_V2:
|
246
249
|
model.model_rebuild(raise_errors=False) # type: ignore # Pydantic v2
|
247
250
|
else:
|
248
|
-
|
251
|
+
try:
|
252
|
+
model.update_forward_refs(**localns)
|
253
|
+
except Exception as e:
|
254
|
+
logger.warning("[WARN] Failed to update forward refs for model %s", model.__name__)
|
249
255
|
|
250
256
|
|
251
257
|
# Mirrors Pydantic's internal typing
|
@@ -111,6 +111,7 @@ class BaseNodeMeta(type):
|
|
111
111
|
# Add cls to relevant nested classes, since python should've been doing this by default
|
112
112
|
for port in node_class.Ports:
|
113
113
|
port.node_class = node_class
|
114
|
+
port.validate()
|
114
115
|
|
115
116
|
node_class.Execution.node_class = node_class
|
116
117
|
node_class.Trigger.node_class = node_class
|
@@ -283,3 +283,23 @@ def test_node_outputs__inherits_instance():
|
|
283
283
|
assert foo_output.instance is undefined
|
284
284
|
assert isinstance(bar_output, OutputReference)
|
285
285
|
assert bar_output.instance == "hello"
|
286
|
+
|
287
|
+
|
288
|
+
def test_base_node__iterate_over_attributes__preserves_order():
|
289
|
+
# GIVEN a node with two attributes
|
290
|
+
class MyNode(BaseNode):
|
291
|
+
foo = "foo"
|
292
|
+
bar = "bar"
|
293
|
+
|
294
|
+
# AND a node that inherits from MyNode
|
295
|
+
class InheritedNode(MyNode):
|
296
|
+
baz = "baz"
|
297
|
+
qux = "qux"
|
298
|
+
quux = "quux"
|
299
|
+
|
300
|
+
# WHEN we iterate over the attributes, multiple times
|
301
|
+
for i in range(10):
|
302
|
+
attribute_names = [attr.name for attr in InheritedNode]
|
303
|
+
|
304
|
+
# THEN the attributes are in the correct order
|
305
|
+
assert attribute_names == ["baz", "qux", "quux", "foo", "bar"], f"Iteration {i} failed"
|
@@ -15,7 +15,7 @@ from vellum import (
|
|
15
15
|
)
|
16
16
|
from vellum.client import ApiError, RequestOptions
|
17
17
|
from vellum.client.types.chat_message_request import ChatMessageRequest
|
18
|
-
from vellum.workflows.constants import LATEST_RELEASE_TAG
|
18
|
+
from vellum.workflows.constants import LATEST_RELEASE_TAG
|
19
19
|
from vellum.workflows.context import get_execution_context
|
20
20
|
from vellum.workflows.errors import WorkflowErrorCode
|
21
21
|
from vellum.workflows.errors.types import vellum_error_to_workflow_error
|
@@ -48,13 +48,13 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
|
|
48
48
|
deployment: ClassVar[Union[UUID, str]]
|
49
49
|
|
50
50
|
release_tag: str = LATEST_RELEASE_TAG
|
51
|
-
external_id: Optional[str] =
|
51
|
+
external_id: Optional[str] = None
|
52
52
|
|
53
|
-
expand_meta: Optional[PromptDeploymentExpandMetaRequest] =
|
54
|
-
raw_overrides: Optional[RawPromptExecutionOverridesRequest] =
|
55
|
-
expand_raw: Optional[Sequence[str]] =
|
56
|
-
metadata: Optional[Dict[str, Optional[Any]]] =
|
57
|
-
ml_model_fallbacks: Optional[Sequence[str]] =
|
53
|
+
expand_meta: Optional[PromptDeploymentExpandMetaRequest] = None
|
54
|
+
raw_overrides: Optional[RawPromptExecutionOverridesRequest] = None
|
55
|
+
expand_raw: Optional[Sequence[str]] = None
|
56
|
+
metadata: Optional[Dict[str, Optional[Any]]] = None
|
57
|
+
ml_model_fallbacks: Optional[Sequence[str]] = None
|
58
58
|
|
59
59
|
class Trigger(BasePromptNode.Trigger):
|
60
60
|
merge_behavior = MergeBehavior.AWAIT_ANY
|
@@ -103,12 +103,7 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
|
|
103
103
|
prompt_event_stream = self._get_prompt_event_stream()
|
104
104
|
next(prompt_event_stream)
|
105
105
|
except ApiError as e:
|
106
|
-
if
|
107
|
-
e.status_code
|
108
|
-
and e.status_code < 500
|
109
|
-
and self.ml_model_fallbacks is not OMIT
|
110
|
-
and self.ml_model_fallbacks is not None
|
111
|
-
):
|
106
|
+
if e.status_code and e.status_code < 500 and self.ml_model_fallbacks is not None:
|
112
107
|
prompt_event_stream = self._retry_prompt_stream_with_fallbacks(tried_fallbacks)
|
113
108
|
else:
|
114
109
|
self._handle_api_error(e)
|
@@ -127,7 +122,6 @@ class BasePromptDeploymentNode(BasePromptNode, Generic[StateType]):
|
|
127
122
|
if (
|
128
123
|
event.error
|
129
124
|
and event.error.code == WorkflowErrorCode.PROVIDER_ERROR.value
|
130
|
-
and self.ml_model_fallbacks is not OMIT
|
131
125
|
and self.ml_model_fallbacks is not None
|
132
126
|
):
|
133
127
|
try:
|
@@ -821,3 +821,65 @@ def main(arg1: list) -> str:
|
|
821
821
|
|
822
822
|
# AND the result should be the correct output
|
823
823
|
assert outputs == {"result": "bar", "log": ""}
|
824
|
+
|
825
|
+
|
826
|
+
def test_run_node__string_value_wrapper__get_attr():
|
827
|
+
# GIVEN a node that accesses the 'value' property of a string input
|
828
|
+
class ExampleCodeExecutionNode(CodeExecutionNode[BaseState, str]):
|
829
|
+
code = """\
|
830
|
+
def main(text: str) -> str:
|
831
|
+
return text.value
|
832
|
+
"""
|
833
|
+
code_inputs = {
|
834
|
+
"text": "hello",
|
835
|
+
}
|
836
|
+
runtime = "PYTHON_3_11_6"
|
837
|
+
|
838
|
+
# WHEN we run the node
|
839
|
+
node = ExampleCodeExecutionNode()
|
840
|
+
outputs = node.run()
|
841
|
+
|
842
|
+
# THEN the node should successfully access the string value through the .value property
|
843
|
+
assert outputs == {"result": "hello", "log": ""}
|
844
|
+
|
845
|
+
|
846
|
+
def test_run_node__string_value_wrapper__get_item():
|
847
|
+
# GIVEN a node that accesses the 'value' property of a string input
|
848
|
+
class ExampleCodeExecutionNode(CodeExecutionNode[BaseState, str]):
|
849
|
+
code = """\
|
850
|
+
def main(text: str) -> str:
|
851
|
+
return text["value"]
|
852
|
+
"""
|
853
|
+
code_inputs = {
|
854
|
+
"text": "hello",
|
855
|
+
}
|
856
|
+
runtime = "PYTHON_3_11_6"
|
857
|
+
|
858
|
+
# WHEN we run the node
|
859
|
+
node = ExampleCodeExecutionNode()
|
860
|
+
outputs = node.run()
|
861
|
+
|
862
|
+
# THEN the node should successfully access the string value through the .value property
|
863
|
+
assert outputs == {"result": "hello", "log": ""}
|
864
|
+
|
865
|
+
|
866
|
+
def test_run_node__string_value_wrapper__list_of_dicts():
|
867
|
+
# GIVEN a node that accesses the 'value' property of a string input
|
868
|
+
class ExampleCodeExecutionNode(CodeExecutionNode[BaseState, Any]):
|
869
|
+
code = """\
|
870
|
+
def main(output: list[str]) -> list[str]:
|
871
|
+
results = []
|
872
|
+
for item in output:
|
873
|
+
results.append(item['value'])
|
874
|
+
|
875
|
+
return results
|
876
|
+
"""
|
877
|
+
code_inputs = {"output": ['{"foo": "bar"}', '{"foo2": "bar2"}']}
|
878
|
+
runtime = "PYTHON_3_11_6"
|
879
|
+
|
880
|
+
# WHEN we run the node
|
881
|
+
node = ExampleCodeExecutionNode()
|
882
|
+
outputs = node.run()
|
883
|
+
|
884
|
+
# THEN the node should successfully access the string value
|
885
|
+
assert outputs == {"result": ['{"foo": "bar"}', '{"foo2": "bar2"}'], "log": ""}
|
@@ -8,6 +8,7 @@ from vellum.workflows.errors.types import WorkflowErrorCode
|
|
8
8
|
from vellum.workflows.exceptions import NodeException
|
9
9
|
from vellum.workflows.nodes.utils import cast_to_output_type
|
10
10
|
from vellum.workflows.state.context import WorkflowContext
|
11
|
+
from vellum.workflows.types.code_execution_node_wrappers import ListWrapper, clean_for_dict_wrapper
|
11
12
|
from vellum.workflows.types.core import EntityInputsInterface
|
12
13
|
|
13
14
|
|
@@ -35,58 +36,6 @@ def read_file_from_path(
|
|
35
36
|
return None
|
36
37
|
|
37
38
|
|
38
|
-
class ListWrapper(list):
|
39
|
-
def __getitem__(self, key):
|
40
|
-
item = super().__getitem__(key)
|
41
|
-
if not isinstance(item, DictWrapper) and not isinstance(item, ListWrapper):
|
42
|
-
self.__setitem__(key, _clean_for_dict_wrapper(item))
|
43
|
-
|
44
|
-
return super().__getitem__(key)
|
45
|
-
|
46
|
-
|
47
|
-
class DictWrapper(dict):
|
48
|
-
"""
|
49
|
-
This wraps a dict object to make it behave basically the same as a standard javascript object
|
50
|
-
and enables us to use vellum types here without a shared library since we don't actually
|
51
|
-
typecheck things here.
|
52
|
-
"""
|
53
|
-
|
54
|
-
def __getitem__(self, key):
|
55
|
-
return self.__getattr__(key)
|
56
|
-
|
57
|
-
def __getattr__(self, attr):
|
58
|
-
if attr not in self:
|
59
|
-
if attr == "value":
|
60
|
-
# In order to be backwards compatible with legacy Workflows, which wrapped
|
61
|
-
# several values as VellumValue objects, we use the "value" key to return itself
|
62
|
-
return self
|
63
|
-
|
64
|
-
raise AttributeError(f"dict has no key: '{attr}'")
|
65
|
-
|
66
|
-
item = super().__getitem__(attr)
|
67
|
-
if not isinstance(item, DictWrapper) and not isinstance(item, ListWrapper):
|
68
|
-
self.__setattr__(attr, _clean_for_dict_wrapper(item))
|
69
|
-
|
70
|
-
return super().__getitem__(attr)
|
71
|
-
|
72
|
-
def __setattr__(self, name, value):
|
73
|
-
self[name] = value
|
74
|
-
|
75
|
-
|
76
|
-
def _clean_for_dict_wrapper(obj):
|
77
|
-
if isinstance(obj, dict):
|
78
|
-
wrapped = DictWrapper(obj)
|
79
|
-
for key in wrapped:
|
80
|
-
wrapped[key] = _clean_for_dict_wrapper(wrapped[key])
|
81
|
-
|
82
|
-
return wrapped
|
83
|
-
|
84
|
-
elif isinstance(obj, list):
|
85
|
-
return ListWrapper(map(lambda item: _clean_for_dict_wrapper(item), obj))
|
86
|
-
|
87
|
-
return obj
|
88
|
-
|
89
|
-
|
90
39
|
def run_code_inline(
|
91
40
|
code: str,
|
92
41
|
inputs: EntityInputsInterface,
|
@@ -107,12 +56,12 @@ def run_code_inline(
|
|
107
56
|
(
|
108
57
|
item.model_dump()
|
109
58
|
if isinstance(item, BaseModel)
|
110
|
-
else
|
59
|
+
else clean_for_dict_wrapper(item) if isinstance(item, (dict, list, str)) else item
|
111
60
|
)
|
112
61
|
for item in value
|
113
62
|
]
|
114
63
|
)
|
115
|
-
return
|
64
|
+
return clean_for_dict_wrapper(value)
|
116
65
|
|
117
66
|
exec_globals = {
|
118
67
|
"__arg__inputs": {name: wrap_value(value) for name, value in inputs.items()},
|
@@ -9,7 +9,6 @@ from vellum import (
|
|
9
9
|
PromptOutput,
|
10
10
|
StringVellumValue,
|
11
11
|
)
|
12
|
-
from vellum.workflows.constants import OMIT
|
13
12
|
from vellum.workflows.inputs import BaseInputs
|
14
13
|
from vellum.workflows.nodes import PromptDeploymentNode
|
15
14
|
from vellum.workflows.state import BaseState
|
@@ -66,14 +65,14 @@ def test_text_prompt_deployment_node__basic(vellum_client):
|
|
66
65
|
|
67
66
|
# AND we should have made the expected call to stream the prompt execution
|
68
67
|
vellum_client.execute_prompt_stream.assert_called_once_with(
|
69
|
-
expand_meta=
|
70
|
-
expand_raw=
|
71
|
-
external_id=
|
68
|
+
expand_meta=None,
|
69
|
+
expand_raw=None,
|
70
|
+
external_id=None,
|
72
71
|
inputs=[],
|
73
|
-
metadata=
|
72
|
+
metadata=None,
|
74
73
|
prompt_deployment_id=None,
|
75
74
|
prompt_deployment_name="my-deployment",
|
76
|
-
raw_overrides=
|
75
|
+
raw_overrides=None,
|
77
76
|
release_tag="LATEST",
|
78
77
|
request_options={
|
79
78
|
"additional_body_parameters": {"execution_context": {"parent_context": None, "trace_id": mock.ANY}}
|
vellum/workflows/nodes/utils.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
from functools import cache
|
2
|
+
import inspect
|
2
3
|
import json
|
3
4
|
import sys
|
4
5
|
from types import ModuleType
|
@@ -14,6 +15,7 @@ from vellum.workflows.nodes import BaseNode
|
|
14
15
|
from vellum.workflows.nodes.bases.base_adornment_node import BaseAdornmentNode
|
15
16
|
from vellum.workflows.ports.port import Port
|
16
17
|
from vellum.workflows.state.base import BaseState
|
18
|
+
from vellum.workflows.types.code_execution_node_wrappers import StringValueWrapper
|
17
19
|
from vellum.workflows.types.core import Json
|
18
20
|
from vellum.workflows.types.generics import NodeType
|
19
21
|
|
@@ -176,6 +178,8 @@ def parse_type_from_str(result_as_str: str, output_type: Any) -> Any:
|
|
176
178
|
|
177
179
|
|
178
180
|
def _get_type_name(obj: Any) -> str:
|
181
|
+
if inspect.isclass(obj) and issubclass(obj, StringValueWrapper):
|
182
|
+
return "str"
|
179
183
|
if isinstance(obj, type):
|
180
184
|
return obj.__name__
|
181
185
|
|
vellum/workflows/ports/port.py
CHANGED
@@ -7,7 +7,7 @@ from vellum.workflows.descriptors.base import BaseDescriptor
|
|
7
7
|
from vellum.workflows.descriptors.exceptions import InvalidExpressionException
|
8
8
|
from vellum.workflows.edges.edge import Edge
|
9
9
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
10
|
-
from vellum.workflows.exceptions import NodeException
|
10
|
+
from vellum.workflows.exceptions import NodeException, WorkflowInitializationException
|
11
11
|
from vellum.workflows.graph import Graph, GraphTarget
|
12
12
|
from vellum.workflows.state.base import BaseState
|
13
13
|
from vellum.workflows.types.core import ConditionType
|
@@ -73,11 +73,11 @@ class Port:
|
|
73
73
|
return Graph.from_edge(edge)
|
74
74
|
|
75
75
|
@staticmethod
|
76
|
-
def on_if(condition: BaseDescriptor, fork_state: bool = False)
|
76
|
+
def on_if(condition: Optional[BaseDescriptor] = None, fork_state: bool = False):
|
77
77
|
return Port(condition=condition, condition_type=ConditionType.IF, fork_state=fork_state)
|
78
78
|
|
79
79
|
@staticmethod
|
80
|
-
def on_elif(condition: BaseDescriptor, fork_state: bool = False) -> "Port":
|
80
|
+
def on_elif(condition: Optional[BaseDescriptor] = None, fork_state: bool = False) -> "Port":
|
81
81
|
return Port(condition=condition, condition_type=ConditionType.ELIF, fork_state=fork_state)
|
82
82
|
|
83
83
|
@staticmethod
|
@@ -107,3 +107,13 @@ class Port:
|
|
107
107
|
cls, source_type: Type[Any], handler: GetCoreSchemaHandler
|
108
108
|
) -> core_schema.CoreSchema:
|
109
109
|
return core_schema.is_instance_schema(cls)
|
110
|
+
|
111
|
+
def validate(self):
|
112
|
+
if (
|
113
|
+
not self.default
|
114
|
+
and self._condition_type in (ConditionType.IF, ConditionType.ELIF)
|
115
|
+
and self._condition is None
|
116
|
+
):
|
117
|
+
raise WorkflowInitializationException(
|
118
|
+
f"Class {self.node_class.__name__}'s {self.name} should have a defined condition and cannot be empty."
|
119
|
+
)
|
@@ -0,0 +1,64 @@
|
|
1
|
+
class StringValueWrapper(str):
|
2
|
+
def __getitem__(self, key):
|
3
|
+
if key == "value":
|
4
|
+
return self
|
5
|
+
raise KeyError(key)
|
6
|
+
|
7
|
+
def __getattr__(self, attr):
|
8
|
+
if attr == "value":
|
9
|
+
return self
|
10
|
+
raise AttributeError(f"'str' object has no attribute '{attr}'")
|
11
|
+
|
12
|
+
|
13
|
+
class ListWrapper(list):
|
14
|
+
def __getitem__(self, key):
|
15
|
+
item = super().__getitem__(key)
|
16
|
+
if not isinstance(item, DictWrapper) and not isinstance(item, ListWrapper):
|
17
|
+
self.__setitem__(key, clean_for_dict_wrapper(item))
|
18
|
+
|
19
|
+
return super().__getitem__(key)
|
20
|
+
|
21
|
+
|
22
|
+
class DictWrapper(dict):
|
23
|
+
"""
|
24
|
+
This wraps a dict object to make it behave basically the same as a standard javascript object
|
25
|
+
and enables us to use vellum types here without a shared library since we don't actually
|
26
|
+
typecheck things here.
|
27
|
+
"""
|
28
|
+
|
29
|
+
def __getitem__(self, key):
|
30
|
+
return self.__getattr__(key)
|
31
|
+
|
32
|
+
def __getattr__(self, attr):
|
33
|
+
if attr not in self:
|
34
|
+
if attr == "value":
|
35
|
+
# In order to be backwards compatible with legacy Workflows, which wrapped
|
36
|
+
# several values as VellumValue objects, we use the "value" key to return itself
|
37
|
+
return self
|
38
|
+
|
39
|
+
raise AttributeError(f"dict has no key: '{attr}'")
|
40
|
+
|
41
|
+
item = super().__getitem__(attr)
|
42
|
+
if not isinstance(item, DictWrapper) and not isinstance(item, ListWrapper):
|
43
|
+
self.__setattr__(attr, clean_for_dict_wrapper(item))
|
44
|
+
|
45
|
+
return super().__getitem__(attr)
|
46
|
+
|
47
|
+
def __setattr__(self, name, value):
|
48
|
+
self[name] = value
|
49
|
+
|
50
|
+
|
51
|
+
def clean_for_dict_wrapper(obj):
|
52
|
+
if isinstance(obj, dict):
|
53
|
+
wrapped = DictWrapper(obj)
|
54
|
+
for key in wrapped:
|
55
|
+
wrapped[key] = clean_for_dict_wrapper(wrapped[key])
|
56
|
+
|
57
|
+
return wrapped
|
58
|
+
|
59
|
+
elif isinstance(obj, list):
|
60
|
+
return ListWrapper(map(lambda item: clean_for_dict_wrapper(item), obj))
|
61
|
+
elif isinstance(obj, str):
|
62
|
+
return StringValueWrapper(obj)
|
63
|
+
|
64
|
+
return obj
|
@@ -83,9 +83,9 @@ def test_infer_types(cls, attr_name, expected_type):
|
|
83
83
|
@pytest.mark.parametrize(
|
84
84
|
"cls, expected_attr_names",
|
85
85
|
[
|
86
|
-
(ExampleClass,
|
87
|
-
(ExampleGenericClass,
|
88
|
-
(ExampleInheritedClass,
|
86
|
+
(ExampleClass, ["beta", "epsilon", "alpha", "gamma", "zeta", "eta", "kappa", "mu"]),
|
87
|
+
(ExampleGenericClass, ["delta"]),
|
88
|
+
(ExampleInheritedClass, ["theta", "beta", "epsilon", "alpha", "gamma", "zeta", "eta", "kappa", "mu"]),
|
89
89
|
],
|
90
90
|
)
|
91
91
|
def test_class_attr_names(cls, expected_attr_names):
|
vellum/workflows/types/utils.py
CHANGED
@@ -7,6 +7,7 @@ from typing import (
|
|
7
7
|
ClassVar,
|
8
8
|
Dict,
|
9
9
|
Generic,
|
10
|
+
List,
|
10
11
|
Optional,
|
11
12
|
Set,
|
12
13
|
Tuple,
|
@@ -101,22 +102,42 @@ def infer_types(object_: Type, attr_name: str, localns: Optional[Dict[str, Any]]
|
|
101
102
|
)
|
102
103
|
|
103
104
|
|
104
|
-
def get_class_attr_names(cls: Type) ->
|
105
|
-
#
|
106
|
-
|
105
|
+
def get_class_attr_names(cls: Type) -> List[str]:
|
106
|
+
# make sure we don't duplicate attributes
|
107
|
+
collected_attributes: Set[str] = set()
|
107
108
|
|
108
|
-
#
|
109
|
-
|
109
|
+
# we want to preserve the order of attributes on each class
|
110
|
+
ordered_attr_names: List[str] = []
|
110
111
|
|
111
|
-
for base in
|
112
|
+
for base in cls.__mro__:
|
113
|
+
# gets attributes declared `foo = 1`
|
114
|
+
for class_attribute in vars(base).keys():
|
115
|
+
if class_attribute in collected_attributes:
|
116
|
+
continue
|
117
|
+
|
118
|
+
if class_attribute.startswith("_"):
|
119
|
+
continue
|
120
|
+
|
121
|
+
collected_attributes.add(class_attribute)
|
122
|
+
ordered_attr_names.append(class_attribute)
|
123
|
+
|
124
|
+
# gets type-annotated attributes `foo: int`
|
112
125
|
ann = base.__dict__.get("__annotations__", {})
|
113
|
-
|
126
|
+
for attr_name in ann.keys():
|
127
|
+
if not isinstance(attr_name, str):
|
128
|
+
continue
|
129
|
+
|
130
|
+
if attr_name in collected_attributes:
|
131
|
+
continue
|
132
|
+
|
133
|
+
if attr_name.startswith("_"):
|
134
|
+
continue
|
114
135
|
|
115
|
-
|
116
|
-
|
136
|
+
collected_attributes.add(attr_name)
|
137
|
+
ordered_attr_names.append(attr_name)
|
117
138
|
|
118
139
|
# combine and filter out private attributes
|
119
|
-
return
|
140
|
+
return ordered_attr_names
|
120
141
|
|
121
142
|
|
122
143
|
def deepcopy_with_exclusions(
|
@@ -1,22 +1,34 @@
|
|
1
1
|
import os
|
2
|
-
from typing import Optional
|
2
|
+
from typing import List, Optional
|
3
3
|
|
4
4
|
from vellum import Vellum, VellumEnvironment
|
5
5
|
|
6
6
|
|
7
|
-
def create_vellum_client(api_key: Optional[str] = None) -> Vellum:
|
7
|
+
def create_vellum_client(api_key: Optional[str] = None, api_url: Optional[str] = None) -> Vellum:
|
8
8
|
if api_key is None:
|
9
9
|
api_key = os.getenv("VELLUM_API_KEY", default="")
|
10
10
|
|
11
11
|
return Vellum(
|
12
12
|
api_key=api_key,
|
13
|
-
environment=create_vellum_environment(),
|
13
|
+
environment=create_vellum_environment(api_url),
|
14
14
|
)
|
15
15
|
|
16
16
|
|
17
|
-
def create_vellum_environment() -> VellumEnvironment:
|
17
|
+
def create_vellum_environment(api_url: Optional[str] = None) -> VellumEnvironment:
|
18
18
|
return VellumEnvironment(
|
19
|
-
default=
|
20
|
-
documents=
|
21
|
-
predict=
|
19
|
+
default=_resolve_env([api_url, "VELLUM_DEFAULT_API_URL", "VELLUM_API_URL"], "https://api.vellum.ai"),
|
20
|
+
documents=_resolve_env([api_url, "VELLUM_DOCUMENTS_API_URL", "VELLUM_API_URL"], "https://documents.vellum.ai"),
|
21
|
+
predict=_resolve_env([api_url, "VELLUM_PREDICT_API_URL", "VELLUM_API_URL"], "https://predict.vellum.ai"),
|
22
22
|
)
|
23
|
+
|
24
|
+
|
25
|
+
def _resolve_env(names: List[Optional[str]], default: str = "") -> str:
|
26
|
+
for name in names:
|
27
|
+
if not name:
|
28
|
+
continue
|
29
|
+
|
30
|
+
value = os.getenv(name)
|
31
|
+
if value:
|
32
|
+
return value
|
33
|
+
|
34
|
+
return default
|