vellum-ai 1.3.3__py3-none-any.whl → 1.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +2 -2
- vellum/client/types/function_definition.py +5 -0
- vellum/client/types/scenario_input_audio_variable_value.py +1 -1
- vellum/client/types/scenario_input_document_variable_value.py +1 -1
- vellum/client/types/scenario_input_image_variable_value.py +1 -1
- vellum/client/types/scenario_input_video_variable_value.py +1 -1
- vellum/workflows/events/node.py +1 -1
- vellum/workflows/events/tests/test_event.py +1 -1
- vellum/workflows/events/workflow.py +1 -1
- vellum/workflows/nodes/bases/base.py +2 -5
- vellum/workflows/nodes/core/map_node/node.py +8 -1
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +2 -2
- vellum/workflows/nodes/displayable/guardrail_node/node.py +8 -3
- vellum/workflows/nodes/displayable/tool_calling_node/node.py +4 -0
- vellum/workflows/nodes/displayable/tool_calling_node/utils.py +17 -2
- vellum/workflows/outputs/base.py +11 -11
- vellum/workflows/references/output.py +3 -5
- vellum/workflows/resolvers/resolver.py +18 -2
- vellum/workflows/resolvers/tests/test_resolver.py +121 -0
- vellum/workflows/runner/runner.py +17 -17
- vellum/workflows/state/encoder.py +0 -37
- vellum/workflows/utils/functions.py +35 -0
- vellum/workflows/workflows/base.py +9 -1
- {vellum_ai-1.3.3.dist-info → vellum_ai-1.3.5.dist-info}/METADATA +1 -1
- {vellum_ai-1.3.3.dist-info → vellum_ai-1.3.5.dist-info}/RECORD +39 -37
- vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +18 -2
- vellum_ee/workflows/display/tests/test_base_workflow_display.py +99 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_prompt_node_serialization.py +1 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_parent_input.py +85 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_serialization.py +2 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_final_output_node_map_reference_serialization.py +88 -0
- vellum_ee/workflows/display/utils/events.py +1 -0
- vellum_ee/workflows/display/utils/expressions.py +56 -0
- vellum_ee/workflows/display/utils/tests/test_events.py +11 -1
- vellum_ee/workflows/display/utils/vellum.py +3 -1
- vellum_ee/workflows/display/workflows/base_workflow_display.py +41 -27
- {vellum_ai-1.3.3.dist-info → vellum_ai-1.3.5.dist-info}/LICENSE +0 -0
- {vellum_ai-1.3.3.dist-info → vellum_ai-1.3.5.dist-info}/WHEEL +0 -0
- {vellum_ai-1.3.3.dist-info → vellum_ai-1.3.5.dist-info}/entry_points.txt +0 -0
@@ -27,10 +27,10 @@ class BaseClientWrapper:
|
|
27
27
|
|
28
28
|
def get_headers(self) -> typing.Dict[str, str]:
|
29
29
|
headers: typing.Dict[str, str] = {
|
30
|
-
"User-Agent": "vellum-ai/1.3.3",
|
30
|
+
"User-Agent": "vellum-ai/1.3.5",
|
31
31
|
"X-Fern-Language": "Python",
|
32
32
|
"X-Fern-SDK-Name": "vellum-ai",
|
33
|
-
"X-Fern-SDK-Version": "1.3.3",
|
33
|
+
"X-Fern-SDK-Version": "1.3.5",
|
34
34
|
**(self.get_custom_headers() or {}),
|
35
35
|
}
|
36
36
|
if self._api_version is not None:
|
@@ -30,6 +30,11 @@ class FunctionDefinition(UniversalBaseModel):
|
|
30
30
|
An OpenAPI specification of parameters that are supported by this function.
|
31
31
|
"""
|
32
32
|
|
33
|
+
inputs: typing.Optional[typing.Dict[str, typing.Optional[typing.Any]]] = pydantic.Field(default=None)
|
34
|
+
"""
|
35
|
+
Optional user defined input mappings for this function.
|
36
|
+
"""
|
37
|
+
|
33
38
|
forced: typing.Optional[bool] = pydantic.Field(default=None)
|
34
39
|
"""
|
35
40
|
Set this option to true to force the model to return a function call of this function.
|
@@ -9,7 +9,7 @@ from .vellum_audio import VellumAudio
|
|
9
9
|
|
10
10
|
class ScenarioInputAudioVariableValue(UniversalBaseModel):
|
11
11
|
type: typing.Literal["AUDIO"] = "AUDIO"
|
12
|
-
value: VellumAudio
|
12
|
+
value: typing.Optional[VellumAudio] = None
|
13
13
|
input_variable_id: str
|
14
14
|
|
15
15
|
if IS_PYDANTIC_V2:
|
@@ -9,7 +9,7 @@ from .vellum_document import VellumDocument
|
|
9
9
|
|
10
10
|
class ScenarioInputDocumentVariableValue(UniversalBaseModel):
|
11
11
|
type: typing.Literal["DOCUMENT"] = "DOCUMENT"
|
12
|
-
value: VellumDocument
|
12
|
+
value: typing.Optional[VellumDocument] = None
|
13
13
|
input_variable_id: str
|
14
14
|
|
15
15
|
if IS_PYDANTIC_V2:
|
@@ -9,7 +9,7 @@ from .vellum_image import VellumImage
|
|
9
9
|
|
10
10
|
class ScenarioInputImageVariableValue(UniversalBaseModel):
|
11
11
|
type: typing.Literal["IMAGE"] = "IMAGE"
|
12
|
-
value: VellumImage
|
12
|
+
value: typing.Optional[VellumImage] = None
|
13
13
|
input_variable_id: str
|
14
14
|
|
15
15
|
if IS_PYDANTIC_V2:
|
@@ -9,7 +9,7 @@ from .vellum_video import VellumVideo
|
|
9
9
|
|
10
10
|
class ScenarioInputVideoVariableValue(UniversalBaseModel):
|
11
11
|
type: typing.Literal["VIDEO"] = "VIDEO"
|
12
|
-
value: VellumVideo
|
12
|
+
value: typing.Optional[VellumVideo] = None
|
13
13
|
input_variable_id: str
|
14
14
|
|
15
15
|
if IS_PYDANTIC_V2:
|
vellum/workflows/events/node.py
CHANGED
@@ -141,7 +141,7 @@ class NodeExecutionFulfilledEvent(_BaseNodeEvent, Generic[OutputsType]):
|
|
141
141
|
|
142
142
|
class NodeExecutionRejectedBody(_BaseNodeExecutionBody):
|
143
143
|
error: WorkflowError
|
144
|
-
|
144
|
+
stacktrace: Optional[str] = None
|
145
145
|
|
146
146
|
|
147
147
|
class NodeExecutionRejectedEvent(_BaseNodeEvent):
|
@@ -156,7 +156,7 @@ class WorkflowExecutionFulfilledEvent(_BaseWorkflowEvent, Generic[OutputsType]):
|
|
156
156
|
|
157
157
|
class WorkflowExecutionRejectedBody(_BaseWorkflowExecutionBody):
|
158
158
|
error: WorkflowError
|
159
|
-
|
159
|
+
stacktrace: Optional[str] = None
|
160
160
|
|
161
161
|
|
162
162
|
class WorkflowExecutionRejectedEvent(_BaseWorkflowEvent):
|
@@ -120,7 +120,7 @@ class BaseNodeMeta(ABCMeta):
|
|
120
120
|
cls = super().__new__(mcs, name, bases, dct)
|
121
121
|
node_class = cast(Type["BaseNode"], cls)
|
122
122
|
|
123
|
-
node_class.Outputs._node_class = node_class
|
123
|
+
node_class.Outputs.__parent_class__ = node_class
|
124
124
|
|
125
125
|
# Add cls to relevant nested classes, since python should've been doing this by default
|
126
126
|
for port in node_class.Ports:
|
@@ -270,11 +270,8 @@ class BaseNode(Generic[StateType], ABC, metaclass=BaseNodeMeta):
|
|
270
270
|
class ExternalInputs(BaseInputs):
|
271
271
|
__descriptor_class__ = ExternalInputReference
|
272
272
|
|
273
|
-
# TODO: Consider using metaclasses to prevent the need for users to specify that the
|
274
|
-
# "Outputs" class inherits from "BaseOutputs" and do so automatically.
|
275
|
-
# https://app.shortcut.com/vellum/story/4008/auto-inherit-basenodeoutputs-in-outputs-classes
|
276
273
|
class Outputs(BaseOutputs):
|
277
|
-
|
274
|
+
__parent_class__: Type["BaseNode"] = field(init=False)
|
278
275
|
|
279
276
|
class Ports(NodePorts):
|
280
277
|
default = Port(default=True)
|
@@ -27,6 +27,7 @@ from vellum.workflows.nodes.bases.base_adornment_node import BaseAdornmentNode
|
|
27
27
|
from vellum.workflows.nodes.utils import create_adornment
|
28
28
|
from vellum.workflows.outputs import BaseOutputs
|
29
29
|
from vellum.workflows.outputs.base import BaseOutput
|
30
|
+
from vellum.workflows.references.node import NodeReference
|
30
31
|
from vellum.workflows.references.output import OutputReference
|
31
32
|
from vellum.workflows.state.context import WorkflowContext
|
32
33
|
from vellum.workflows.types.generics import StateType
|
@@ -217,5 +218,11 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
|
|
217
218
|
# value: List[str]
|
218
219
|
outputs_class.__annotations__ = {**previous_annotations, reference.name: annotation}
|
219
220
|
|
220
|
-
|
221
|
+
subworkflow_class = cls.subworkflow.instance if isinstance(cls.subworkflow, NodeReference) else None
|
222
|
+
if subworkflow_class:
|
223
|
+
output_id = subworkflow_class.__output_ids__.get(reference.name) or uuid4_from_hash(
|
224
|
+
f"{cls.__id__}|{reference.name}"
|
225
|
+
)
|
226
|
+
else:
|
227
|
+
output_id = uuid4_from_hash(f"{cls.__id__}|{reference.name}")
|
221
228
|
cls.__output_ids__[reference.name] = output_id
|
@@ -100,7 +100,7 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
|
|
100
100
|
execution_context = get_execution_context()
|
101
101
|
request_options = self.request_options or RequestOptions()
|
102
102
|
|
103
|
-
processed_parameters = self.
|
103
|
+
processed_parameters = self.process_parameters(self.parameters)
|
104
104
|
|
105
105
|
request_options["additional_body_parameters"] = {
|
106
106
|
"execution_context": execution_context.model_dump(mode="json"),
|
@@ -300,7 +300,7 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
|
|
300
300
|
|
301
301
|
return input_variables, input_values
|
302
302
|
|
303
|
-
def
|
303
|
+
def process_parameters(self, parameters: PromptParameters) -> PromptParameters:
|
304
304
|
"""
|
305
305
|
Process parameters to recursively convert any Pydantic models to JSON schema dictionaries.
|
306
306
|
"""
|
@@ -8,7 +8,6 @@ from vellum.workflows.constants import LATEST_RELEASE_TAG
|
|
8
8
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
9
9
|
from vellum.workflows.exceptions import NodeException
|
10
10
|
from vellum.workflows.nodes.bases import BaseNode
|
11
|
-
from vellum.workflows.outputs.base import BaseOutputs
|
12
11
|
from vellum.workflows.types.core import EntityInputsInterface, MergeBehavior
|
13
12
|
from vellum.workflows.types.generics import StateType
|
14
13
|
|
@@ -33,7 +32,7 @@ class GuardrailNode(BaseNode[StateType], Generic[StateType]):
|
|
33
32
|
class Trigger(BaseNode.Trigger):
|
34
33
|
merge_behavior = MergeBehavior.AWAIT_ANY
|
35
34
|
|
36
|
-
class Outputs(BaseOutputs):
|
35
|
+
class Outputs(BaseNode.Outputs):
|
37
36
|
score: float
|
38
37
|
normalized_score: Optional[float]
|
39
38
|
log: Optional[str]
|
@@ -98,7 +97,13 @@ class GuardrailNode(BaseNode[StateType], Generic[StateType]):
|
|
98
97
|
else:
|
99
98
|
reason = None
|
100
99
|
|
101
|
-
return self.Outputs(
|
100
|
+
return self.Outputs(
|
101
|
+
score=score,
|
102
|
+
normalized_score=normalized_score,
|
103
|
+
log=log,
|
104
|
+
reason=reason,
|
105
|
+
**metric_outputs, # type: ignore [arg-type]
|
106
|
+
)
|
102
107
|
|
103
108
|
def _compile_metric_inputs(self) -> List[MetricDefinitionInput]:
|
104
109
|
# TODO: We may want to consolidate with prompt deployment input compilation
|
@@ -138,6 +138,9 @@ class ToolCallingNode(BaseNode[StateType], Generic[StateType]):
|
|
138
138
|
)
|
139
139
|
|
140
140
|
def _build_graph(self) -> None:
|
141
|
+
# Get the process_parameters method if it exists on this class
|
142
|
+
process_parameters_method = getattr(self.__class__, "process_parameters", None)
|
143
|
+
|
141
144
|
self.tool_prompt_node = create_tool_prompt_node(
|
142
145
|
ml_model=self.ml_model,
|
143
146
|
blocks=self.blocks,
|
@@ -145,6 +148,7 @@ class ToolCallingNode(BaseNode[StateType], Generic[StateType]):
|
|
145
148
|
prompt_inputs=self.prompt_inputs,
|
146
149
|
parameters=self.parameters,
|
147
150
|
max_prompt_iterations=self.max_prompt_iterations,
|
151
|
+
process_parameters_method=process_parameters_method,
|
148
152
|
)
|
149
153
|
|
150
154
|
# Create the router node (handles routing logic only)
|
@@ -329,6 +329,7 @@ def create_tool_prompt_node(
|
|
329
329
|
prompt_inputs: Optional[EntityInputsInterface],
|
330
330
|
parameters: PromptParameters,
|
331
331
|
max_prompt_iterations: Optional[int] = None,
|
332
|
+
process_parameters_method: Optional[Callable] = None,
|
332
333
|
) -> Type[ToolPromptNode]:
|
333
334
|
if functions and len(functions) > 0:
|
334
335
|
prompt_functions: List[Union[Tool, FunctionDefinition]] = []
|
@@ -398,6 +399,7 @@ def create_tool_prompt_node(
|
|
398
399
|
"prompt_inputs": node_prompt_inputs,
|
399
400
|
"parameters": parameters,
|
400
401
|
"max_prompt_iterations": max_prompt_iterations,
|
402
|
+
**({"process_parameters": process_parameters_method} if process_parameters_method is not None else {}),
|
401
403
|
"__module__": __name__,
|
402
404
|
},
|
403
405
|
),
|
@@ -520,12 +522,25 @@ def create_function_node(
|
|
520
522
|
},
|
521
523
|
)
|
522
524
|
else:
|
523
|
-
|
525
|
+
|
526
|
+
def create_function_wrapper(func):
|
527
|
+
def wrapper(self, **kwargs):
|
528
|
+
merged_kwargs = kwargs.copy()
|
529
|
+
inputs = getattr(func, "__vellum_inputs__", {})
|
530
|
+
if inputs:
|
531
|
+
for param_name, param_ref in inputs.items():
|
532
|
+
resolved_value = param_ref.resolve(self.state)
|
533
|
+
merged_kwargs[param_name] = resolved_value
|
534
|
+
|
535
|
+
return func(**merged_kwargs)
|
536
|
+
|
537
|
+
return wrapper
|
538
|
+
|
524
539
|
node = type(
|
525
540
|
f"FunctionNode_{function.__name__}",
|
526
541
|
(FunctionNode,),
|
527
542
|
{
|
528
|
-
"function_definition":
|
543
|
+
"function_definition": create_function_wrapper(function),
|
529
544
|
"function_call_output": tool_prompt_node.Outputs.results,
|
530
545
|
"__module__": __name__,
|
531
546
|
},
|
vellum/workflows/outputs/base.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
|
+
from dataclasses import field
|
1
2
|
import inspect
|
2
|
-
from typing import
|
3
|
+
from typing import Any, Generic, Iterator, Set, Tuple, Type, TypeVar, Union, cast
|
3
4
|
from typing_extensions import dataclass_transform
|
4
5
|
|
5
6
|
from pydantic import GetCoreSchemaHandler
|
@@ -13,9 +14,6 @@ from vellum.workflows.references.output import OutputReference
|
|
13
14
|
from vellum.workflows.types.generics import is_node_instance
|
14
15
|
from vellum.workflows.types.utils import get_class_attr_names, infer_types
|
15
16
|
|
16
|
-
if TYPE_CHECKING:
|
17
|
-
from vellum.workflows.nodes.bases.base import BaseNode
|
18
|
-
|
19
17
|
_Delta = TypeVar("_Delta")
|
20
18
|
_Accumulated = TypeVar("_Accumulated")
|
21
19
|
|
@@ -112,18 +110,16 @@ class _BaseOutputsMeta(type):
|
|
112
110
|
if not cls.__qualname__.endswith(".Outputs") or not other.__qualname__.endswith(".Outputs"):
|
113
111
|
return super().__eq__(other)
|
114
112
|
|
115
|
-
self_outputs_class = cast(Type["
|
116
|
-
other_outputs_class = cast(Type["
|
113
|
+
self_outputs_class = cast(Type["BaseOutputs"], cls)
|
114
|
+
other_outputs_class = cast(Type["BaseOutputs"], other)
|
117
115
|
|
118
|
-
if not hasattr(self_outputs_class, "_node_class") or not hasattr(other_outputs_class, "_node_class"):
|
116
|
+
if not hasattr(self_outputs_class, "__parent_class__") or not hasattr(other_outputs_class, "__parent_class__"):
|
119
117
|
return super().__eq__(other)
|
120
118
|
|
121
|
-
if self_outputs_class._node_class is None or other_outputs_class._node_class is None:
|
119
|
+
if self_outputs_class.__parent_class__ is None or other_outputs_class.__parent_class__ is None:
|
122
120
|
return super().__eq__(other)
|
123
121
|
|
124
|
-
return
|
125
|
-
other_outputs_class._node_class, "__qualname__"
|
126
|
-
)
|
122
|
+
return self_outputs_class.__parent_class__.__qualname__ == other_outputs_class.__parent_class__.__qualname__
|
127
123
|
|
128
124
|
def __setattr__(cls, name: str, value: Any) -> None:
|
129
125
|
if isinstance(value, OutputReference):
|
@@ -187,6 +183,10 @@ class _BaseOutputsMeta(type):
|
|
187
183
|
|
188
184
|
|
189
185
|
class BaseOutputs(metaclass=_BaseOutputsMeta):
|
186
|
+
# TODO: Uncomment once we figure out why this causes a failure in `infer_types`
|
187
|
+
# __parent_class__: Type[Union["BaseNode", "BaseWorkflow"]] = field(init=False)
|
188
|
+
__parent_class__: Type = field(init=False)
|
189
|
+
|
190
190
|
def __init__(self, **kwargs: Any) -> None:
|
191
191
|
declared_fields = {descriptor.name for descriptor in self.__class__}
|
192
192
|
provided_fields = set(kwargs.keys())
|
@@ -35,13 +35,11 @@ class OutputReference(BaseDescriptor[_OutputType], Generic[_OutputType]):
|
|
35
35
|
|
36
36
|
@cached_property
|
37
37
|
def id(self) -> UUID:
|
38
|
-
|
39
|
-
|
40
|
-
node_class = getattr(self._outputs_class, "_node_class", None)
|
41
|
-
if not node_class:
|
38
|
+
parent_class = self._outputs_class.__parent_class__
|
39
|
+
if not parent_class:
|
42
40
|
return uuid4()
|
43
41
|
|
44
|
-
output_ids = getattr(node_class, "__output_ids__", {})
|
42
|
+
output_ids = getattr(parent_class, "__output_ids__", {})
|
45
43
|
if not isinstance(output_ids, dict):
|
46
44
|
return uuid4()
|
47
45
|
|
@@ -1,10 +1,11 @@
|
|
1
1
|
import logging
|
2
2
|
from uuid import UUID
|
3
|
-
from typing import Iterator, List, Optional, Tuple, Union
|
3
|
+
from typing import Iterator, List, Optional, Tuple, Type, Union
|
4
4
|
|
5
5
|
from vellum.client.types.vellum_span import VellumSpan
|
6
6
|
from vellum.client.types.workflow_execution_initiated_event import WorkflowExecutionInitiatedEvent
|
7
7
|
from vellum.workflows.events.workflow import WorkflowEvent
|
8
|
+
from vellum.workflows.nodes.utils import cast_to_output_type
|
8
9
|
from vellum.workflows.resolvers.base import BaseWorkflowResolver
|
9
10
|
from vellum.workflows.resolvers.types import LoadStateResult
|
10
11
|
from vellum.workflows.state.base import BaseState
|
@@ -51,6 +52,21 @@ class VellumResolver(BaseWorkflowResolver):
|
|
51
52
|
|
52
53
|
return previous_trace_id, root_trace_id, previous_span_id, root_span_id
|
53
54
|
|
55
|
+
def _deserialize_state(self, state_data: dict, state_class: Type[BaseState]) -> BaseState:
|
56
|
+
"""Deserialize state data with proper type conversion for complex types like List[ChatMessage]."""
|
57
|
+
converted_data = {}
|
58
|
+
|
59
|
+
annotations = getattr(state_class, "__annotations__", {})
|
60
|
+
|
61
|
+
for field_name, field_value in state_data.items():
|
62
|
+
if field_name in annotations:
|
63
|
+
field_type = annotations[field_name]
|
64
|
+
converted_data[field_name] = cast_to_output_type(field_value, field_type)
|
65
|
+
else:
|
66
|
+
converted_data[field_name] = field_value
|
67
|
+
|
68
|
+
return state_class(**converted_data)
|
69
|
+
|
54
70
|
def load_state(self, previous_execution_id: Optional[Union[UUID, str]] = None) -> Optional[LoadStateResult]:
|
55
71
|
if isinstance(previous_execution_id, UUID):
|
56
72
|
previous_execution_id = str(previous_execution_id)
|
@@ -83,7 +99,7 @@ class VellumResolver(BaseWorkflowResolver):
|
|
83
99
|
|
84
100
|
if self._workflow_class:
|
85
101
|
state_class = self._workflow_class.get_state_class()
|
86
|
-
state =
|
102
|
+
state = self._deserialize_state(response.state, state_class)
|
87
103
|
else:
|
88
104
|
logger.warning("No workflow class registered, falling back to BaseState")
|
89
105
|
state = BaseState(**response.state)
|
@@ -1,7 +1,9 @@
|
|
1
1
|
from datetime import datetime
|
2
2
|
from unittest.mock import Mock
|
3
3
|
from uuid import uuid4
|
4
|
+
from typing import List
|
4
5
|
|
6
|
+
from vellum import ChatMessage
|
5
7
|
from vellum.client.types.span_link import SpanLink
|
6
8
|
from vellum.client.types.vellum_code_resource_definition import VellumCodeResourceDefinition
|
7
9
|
from vellum.client.types.workflow_execution_detail import WorkflowExecutionDetail
|
@@ -129,3 +131,122 @@ def test_load_state_with_context_success():
|
|
129
131
|
mock_client.workflow_executions.retrieve_workflow_execution_detail.assert_called_once_with(
|
130
132
|
execution_id=str(execution_id)
|
131
133
|
)
|
134
|
+
|
135
|
+
|
136
|
+
def test_load_state_with_chat_message_list():
|
137
|
+
"""Test load_state successfully loads state with chat_history containing ChatMessage list."""
|
138
|
+
resolver = VellumResolver()
|
139
|
+
execution_id = uuid4()
|
140
|
+
root_execution_id = uuid4()
|
141
|
+
|
142
|
+
class TestStateWithChatHistory(BaseState):
|
143
|
+
test_key: str = "test_value"
|
144
|
+
chat_history: List[ChatMessage] = []
|
145
|
+
|
146
|
+
class TestWorkflow(BaseWorkflow[BaseInputs, TestStateWithChatHistory]):
|
147
|
+
pass
|
148
|
+
|
149
|
+
# GIVEN a state dictionary with chat_history containing ChatMessage objects
|
150
|
+
prev_id = str(uuid4())
|
151
|
+
prev_span_id = str(uuid4())
|
152
|
+
state_dict = {
|
153
|
+
"test_key": "test_value",
|
154
|
+
"chat_history": [
|
155
|
+
{"role": "USER", "text": "Hello, how are you?"},
|
156
|
+
{"role": "ASSISTANT", "text": "I'm doing well, thank you!"},
|
157
|
+
{"role": "USER", "text": "What can you help me with?"},
|
158
|
+
],
|
159
|
+
"meta": {
|
160
|
+
"workflow_definition": "MockWorkflow",
|
161
|
+
"id": prev_id,
|
162
|
+
"span_id": prev_span_id,
|
163
|
+
"updated_ts": datetime.now().isoformat(),
|
164
|
+
"workflow_inputs": BaseInputs(),
|
165
|
+
"external_inputs": {},
|
166
|
+
"node_outputs": {},
|
167
|
+
"node_execution_cache": NodeExecutionCache(),
|
168
|
+
"parent": None,
|
169
|
+
},
|
170
|
+
}
|
171
|
+
|
172
|
+
mock_workflow_definition = VellumCodeResourceDefinition(
|
173
|
+
name="TestWorkflow", module=["test", "module"], id=str(uuid4())
|
174
|
+
)
|
175
|
+
|
176
|
+
mock_body = WorkflowExecutionInitiatedBody(workflow_definition=mock_workflow_definition, inputs={})
|
177
|
+
|
178
|
+
previous_trace_id = str(uuid4())
|
179
|
+
root_trace_id = str(uuid4())
|
180
|
+
|
181
|
+
previous_invocation = WorkflowExecutionInitiatedEvent(
|
182
|
+
id=str(uuid4()),
|
183
|
+
timestamp=datetime.now(),
|
184
|
+
trace_id=previous_trace_id,
|
185
|
+
span_id=str(execution_id),
|
186
|
+
body=mock_body,
|
187
|
+
links=[
|
188
|
+
SpanLink(
|
189
|
+
trace_id=previous_trace_id,
|
190
|
+
type="PREVIOUS_SPAN",
|
191
|
+
span_context=WorkflowParentContext(workflow_definition=mock_workflow_definition, span_id=str(uuid4())),
|
192
|
+
),
|
193
|
+
SpanLink(
|
194
|
+
trace_id=root_trace_id,
|
195
|
+
type="ROOT_SPAN",
|
196
|
+
span_context=WorkflowParentContext(
|
197
|
+
workflow_definition=mock_workflow_definition, span_id=str(root_execution_id)
|
198
|
+
),
|
199
|
+
),
|
200
|
+
],
|
201
|
+
)
|
202
|
+
|
203
|
+
root_invocation = WorkflowExecutionInitiatedEvent(
|
204
|
+
id=str(uuid4()),
|
205
|
+
timestamp=datetime.now(),
|
206
|
+
trace_id=root_trace_id,
|
207
|
+
span_id=str(root_execution_id),
|
208
|
+
body=mock_body,
|
209
|
+
links=None,
|
210
|
+
)
|
211
|
+
|
212
|
+
mock_span = WorkflowExecutionSpan(
|
213
|
+
span_id=str(execution_id),
|
214
|
+
start_ts=datetime.now(),
|
215
|
+
end_ts=datetime.now(),
|
216
|
+
attributes=WorkflowExecutionSpanAttributes(label="Test Workflow", workflow_id=str(uuid4())),
|
217
|
+
events=[previous_invocation, root_invocation],
|
218
|
+
)
|
219
|
+
|
220
|
+
mock_response = WorkflowExecutionDetail(
|
221
|
+
span_id="test-span-id", start=datetime.now(), inputs=[], outputs=[], spans=[mock_span], state=state_dict
|
222
|
+
)
|
223
|
+
|
224
|
+
mock_client = Mock()
|
225
|
+
mock_client.workflow_executions.retrieve_workflow_execution_detail.return_value = mock_response
|
226
|
+
|
227
|
+
# AND context with the test workflow class is set up
|
228
|
+
context = WorkflowContext(vellum_client=mock_client)
|
229
|
+
TestWorkflow(context=context, resolvers=[resolver])
|
230
|
+
|
231
|
+
# WHEN load_state is called
|
232
|
+
result = resolver.load_state(previous_execution_id=execution_id)
|
233
|
+
|
234
|
+
# THEN should return LoadStateResult with state containing chat_history
|
235
|
+
assert isinstance(result, LoadStateResult)
|
236
|
+
assert result.state is not None
|
237
|
+
assert isinstance(result.state, TestStateWithChatHistory)
|
238
|
+
assert result.state.test_key == "test_value"
|
239
|
+
|
240
|
+
# AND the chat_history should be properly deserialized as ChatMessage objects
|
241
|
+
assert len(result.state.chat_history) == 3
|
242
|
+
assert all(isinstance(msg, ChatMessage) for msg in result.state.chat_history)
|
243
|
+
assert result.state.chat_history[0].role == "USER"
|
244
|
+
assert result.state.chat_history[0].text == "Hello, how are you?"
|
245
|
+
assert result.state.chat_history[1].role == "ASSISTANT"
|
246
|
+
assert result.state.chat_history[1].text == "I'm doing well, thank you!"
|
247
|
+
assert result.state.chat_history[2].role == "USER"
|
248
|
+
assert result.state.chat_history[2].text == "What can you help me with?"
|
249
|
+
|
250
|
+
mock_client.workflow_executions.retrieve_workflow_execution_detail.assert_called_once_with(
|
251
|
+
execution_id=str(execution_id)
|
252
|
+
)
|
@@ -404,7 +404,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
404
404
|
)
|
405
405
|
except NodeException as e:
|
406
406
|
logger.info(e)
|
407
|
-
|
407
|
+
captured_stacktrace = traceback.format_exc()
|
408
408
|
|
409
409
|
self._workflow_event_inner_queue.put(
|
410
410
|
NodeExecutionRejectedEvent(
|
@@ -413,14 +413,14 @@ class WorkflowRunner(Generic[StateType]):
|
|
413
413
|
body=NodeExecutionRejectedBody(
|
414
414
|
node_definition=node.__class__,
|
415
415
|
error=e.error,
|
416
|
-
|
416
|
+
stacktrace=captured_stacktrace,
|
417
417
|
),
|
418
418
|
parent=execution.parent_context,
|
419
419
|
)
|
420
420
|
)
|
421
421
|
except WorkflowInitializationException as e:
|
422
422
|
logger.info(e)
|
423
|
-
|
423
|
+
captured_stacktrace = traceback.format_exc()
|
424
424
|
self._workflow_event_inner_queue.put(
|
425
425
|
NodeExecutionRejectedEvent(
|
426
426
|
trace_id=execution.trace_id,
|
@@ -428,7 +428,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
428
428
|
body=NodeExecutionRejectedBody(
|
429
429
|
node_definition=node.__class__,
|
430
430
|
error=e.error,
|
431
|
-
|
431
|
+
stacktrace=captured_stacktrace,
|
432
432
|
),
|
433
433
|
parent=execution.parent_context,
|
434
434
|
)
|
@@ -713,13 +713,13 @@ class WorkflowRunner(Generic[StateType]):
|
|
713
713
|
)
|
714
714
|
|
715
715
|
def _reject_workflow_event(
|
716
|
-
self, error: WorkflowError,
|
716
|
+
self, error: WorkflowError, captured_stacktrace: Optional[str] = None
|
717
717
|
) -> WorkflowExecutionRejectedEvent:
|
718
|
-
if
|
718
|
+
if captured_stacktrace is None:
|
719
719
|
try:
|
720
|
-
|
721
|
-
if
|
722
|
-
|
720
|
+
captured_stacktrace = traceback.format_exc()
|
721
|
+
if captured_stacktrace.strip() == "NoneType: None":
|
722
|
+
captured_stacktrace = None
|
723
723
|
except Exception:
|
724
724
|
pass
|
725
725
|
|
@@ -729,7 +729,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
729
729
|
body=WorkflowExecutionRejectedBody(
|
730
730
|
workflow_definition=self.workflow.__class__,
|
731
731
|
error=error,
|
732
|
-
|
732
|
+
stacktrace=captured_stacktrace,
|
733
733
|
),
|
734
734
|
parent=self._execution_context.parent_context,
|
735
735
|
)
|
@@ -773,21 +773,21 @@ class WorkflowRunner(Generic[StateType]):
|
|
773
773
|
else:
|
774
774
|
self._concurrency_queue.put((self._initial_state, node_cls, None))
|
775
775
|
except NodeException as e:
|
776
|
-
|
777
|
-
self._workflow_event_outer_queue.put(self._reject_workflow_event(e.error,
|
776
|
+
captured_stacktrace = traceback.format_exc()
|
777
|
+
self._workflow_event_outer_queue.put(self._reject_workflow_event(e.error, captured_stacktrace))
|
778
778
|
return
|
779
779
|
except WorkflowInitializationException as e:
|
780
|
-
|
781
|
-
self._workflow_event_outer_queue.put(self._reject_workflow_event(e.error,
|
780
|
+
captured_stacktrace = traceback.format_exc()
|
781
|
+
self._workflow_event_outer_queue.put(self._reject_workflow_event(e.error, captured_stacktrace))
|
782
782
|
return
|
783
783
|
except Exception:
|
784
784
|
err_message = f"An unexpected error occurred while initializing node {node_cls.__name__}"
|
785
785
|
logger.exception(err_message)
|
786
|
-
|
786
|
+
captured_stacktrace = traceback.format_exc()
|
787
787
|
self._workflow_event_outer_queue.put(
|
788
788
|
self._reject_workflow_event(
|
789
789
|
WorkflowError(code=WorkflowErrorCode.INTERNAL_ERROR, message=err_message),
|
790
|
-
|
790
|
+
captured_stacktrace,
|
791
791
|
)
|
792
792
|
)
|
793
793
|
return
|
@@ -838,7 +838,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
838
838
|
|
839
839
|
if rejection_event:
|
840
840
|
self._workflow_event_outer_queue.put(
|
841
|
-
self._reject_workflow_event(rejection_event.error, rejection_event.body.
|
841
|
+
self._reject_workflow_event(rejection_event.error, rejection_event.body.stacktrace)
|
842
842
|
)
|
843
843
|
return
|
844
844
|
|
@@ -1,11 +1,8 @@
|
|
1
1
|
from dataclasses import asdict, is_dataclass
|
2
2
|
from datetime import datetime
|
3
3
|
import enum
|
4
|
-
import inspect
|
5
|
-
from io import StringIO
|
6
4
|
from json import JSONEncoder
|
7
5
|
from queue import Queue
|
8
|
-
import sys
|
9
6
|
from uuid import UUID
|
10
7
|
from typing import Any, Callable, Dict, Type
|
11
8
|
|
@@ -17,23 +14,6 @@ from vellum.workflows.inputs.base import BaseInputs
|
|
17
14
|
from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
|
18
15
|
from vellum.workflows.ports.port import Port
|
19
16
|
from vellum.workflows.state.base import BaseState, NodeExecutionCache
|
20
|
-
from vellum.workflows.utils.functions import compile_function_definition
|
21
|
-
|
22
|
-
|
23
|
-
def virtual_open(file_path: str, mode: str = "r"):
|
24
|
-
"""
|
25
|
-
Open a file, checking VirtualFileFinder instances first before falling back to regular open().
|
26
|
-
"""
|
27
|
-
for finder in sys.meta_path:
|
28
|
-
if hasattr(finder, "loader") and hasattr(finder.loader, "_get_code"):
|
29
|
-
namespace = finder.loader.namespace
|
30
|
-
if file_path.startswith(namespace + "/"):
|
31
|
-
relative_path = file_path[len(namespace) + 1 :]
|
32
|
-
content = finder.loader._get_code(relative_path)
|
33
|
-
if content is not None:
|
34
|
-
return StringIO(content)
|
35
|
-
|
36
|
-
return open(file_path, mode)
|
37
17
|
|
38
18
|
|
39
19
|
class DefaultStateEncoder(JSONEncoder):
|
@@ -80,23 +60,6 @@ class DefaultStateEncoder(JSONEncoder):
|
|
80
60
|
if isinstance(obj, type):
|
81
61
|
return str(obj)
|
82
62
|
|
83
|
-
if callable(obj):
|
84
|
-
function_definition = compile_function_definition(obj)
|
85
|
-
source_path = inspect.getsourcefile(obj)
|
86
|
-
if source_path is not None:
|
87
|
-
with virtual_open(source_path) as f:
|
88
|
-
source_code = f.read()
|
89
|
-
else:
|
90
|
-
source_code = f"# Error: Source code not available for {obj.__name__}"
|
91
|
-
|
92
|
-
return {
|
93
|
-
"type": "CODE_EXECUTION",
|
94
|
-
"name": function_definition.name,
|
95
|
-
"description": function_definition.description,
|
96
|
-
"definition": function_definition,
|
97
|
-
"src": source_code,
|
98
|
-
}
|
99
|
-
|
100
63
|
if obj.__class__ in self.encoders:
|
101
64
|
return self.encoders[obj.__class__](obj)
|
102
65
|
|