vellum-ai 0.14.37__py3-none-any.whl → 0.14.38__py3-none-any.whl
This diff compares the contents of two package versions as published to a supported public registry. It is provided for informational purposes only and reflects each version as it appears in that registry.
- vellum/__init__.py +8 -0
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/reference.md +6272 -0
- vellum/client/types/__init__.py +8 -0
- vellum/client/types/ad_hoc_fulfilled_prompt_execution_meta.py +2 -0
- vellum/client/types/fulfilled_prompt_execution_meta.py +2 -0
- vellum/client/types/test_suite_run_exec_config_request.py +4 -0
- vellum/client/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +27 -0
- vellum/client/types/test_suite_run_prompt_sandbox_exec_config_request.py +29 -0
- vellum/client/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +22 -0
- vellum/client/types/test_suite_run_workflow_sandbox_exec_config_request.py +29 -0
- vellum/plugins/pydantic.py +1 -1
- vellum/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +3 -0
- vellum/types/test_suite_run_prompt_sandbox_exec_config_request.py +3 -0
- vellum/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +3 -0
- vellum/types/test_suite_run_workflow_sandbox_exec_config_request.py +3 -0
- vellum/workflows/events/node.py +2 -1
- vellum/workflows/events/types.py +3 -40
- vellum/workflows/events/workflow.py +2 -1
- vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +94 -3
- vellum/workflows/nodes/displayable/conftest.py +2 -6
- vellum/workflows/nodes/displayable/guardrail_node/node.py +1 -1
- vellum/workflows/nodes/displayable/guardrail_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/guardrail_node/tests/test_node.py +50 -0
- vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +297 -0
- vellum/workflows/runner/runner.py +44 -43
- vellum/workflows/state/base.py +149 -45
- vellum/workflows/types/definition.py +71 -0
- vellum/workflows/types/generics.py +34 -1
- vellum/workflows/workflows/base.py +20 -3
- vellum/workflows/workflows/tests/test_base_workflow.py +232 -1
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/RECORD +37 -25
- vellum_ee/workflows/display/vellum.py +0 -5
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/entry_points.txt +0 -0
vellum/workflows/state/base.py
CHANGED
@@ -6,7 +6,7 @@ import logging
 from queue import Queue
 from threading import Lock
 from uuid import UUID, uuid4
-from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional,
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, cast
 from typing_extensions import dataclass_transform
 
 from pydantic import GetCoreSchemaHandler, ValidationInfo, field_serializer, field_validator
@@ -16,19 +16,14 @@ from vellum.core.pydantic_utilities import UniversalBaseModel
 from vellum.workflows.constants import undefined
 from vellum.workflows.edges.edge import Edge
 from vellum.workflows.inputs.base import BaseInputs
-from vellum.workflows.outputs.base import BaseOutputs
 from vellum.workflows.references import ExternalInputReference, OutputReference, StateValueReference
-from vellum.workflows.types.
+from vellum.workflows.types.definition import CodeResourceDefinition, serialize_type_encoder_with_id
+from vellum.workflows.types.generics import StateType, import_workflow_class, is_workflow_class
 from vellum.workflows.types.stack import Stack
-from vellum.workflows.types.utils import (
-    datetime_now,
-    deepcopy_with_exclusions,
-    get_class_attr_names,
-    get_class_by_qualname,
-    infer_types,
-)
+from vellum.workflows.types.utils import datetime_now, deepcopy_with_exclusions, get_class_attr_names, infer_types
 
 if TYPE_CHECKING:
+    from vellum.workflows import BaseWorkflow
     from vellum.workflows.nodes.bases import BaseNode
 
 logger = logging.getLogger(__name__)
@@ -94,38 +89,65 @@ def _make_snapshottable(value: Any, snapshot_callback: Callable[[], None]) -> Any:
     return value
 
 
+NodeExecutionsFulfilled = Dict[Type["BaseNode"], Stack[UUID]]
+NodeExecutionsInitiated = Dict[Type["BaseNode"], Set[UUID]]
+NodeExecutionsQueued = Dict[Type["BaseNode"], List[UUID]]
+DependenciesInvoked = Dict[UUID, Set[Type["BaseNode"]]]
+
+
 class NodeExecutionCache:
-    _node_executions_fulfilled:
-    _node_executions_initiated:
-    _node_executions_queued:
-    _dependencies_invoked:
+    _node_executions_fulfilled: NodeExecutionsFulfilled
+    _node_executions_initiated: NodeExecutionsInitiated
+    _node_executions_queued: NodeExecutionsQueued
+    _dependencies_invoked: DependenciesInvoked
 
-    def __init__(
-        self,
-        dependencies_invoked: Optional[Dict[str, Sequence[str]]] = None,
-        node_executions_fulfilled: Optional[Dict[str, Sequence[str]]] = None,
-        node_executions_initiated: Optional[Dict[str, Sequence[str]]] = None,
-        node_executions_queued: Optional[Dict[str, Sequence[str]]] = None,
-    ) -> None:
+    def __init__(self) -> None:
         self._dependencies_invoked = defaultdict(set)
         self._node_executions_fulfilled = defaultdict(Stack[UUID])
         self._node_executions_initiated = defaultdict(set)
         self._node_executions_queued = defaultdict(list)
 
-
-
+    @classmethod
+    def deserialize(cls, raw_data: dict, nodes: Dict[str, Type["BaseNode"]]):
+        cache = cls()
+
+        dependencies_invoked = raw_data.get("dependencies_invoked")
+        if isinstance(dependencies_invoked, dict):
+            for execution_id, dependencies in dependencies_invoked.items():
+                cache._dependencies_invoked[UUID(execution_id)] = {nodes[dep] for dep in dependencies if dep in nodes}
+
+        node_executions_fulfilled = raw_data.get("node_executions_fulfilled")
+        if isinstance(node_executions_fulfilled, dict):
+            for node, execution_ids in node_executions_fulfilled.items():
+                node_class = nodes.get(node)
+                if not node_class:
+                    continue
+
+                cache._node_executions_fulfilled[node_class].extend(
+                    UUID(execution_id) for execution_id in execution_ids
+                )
+
+        node_executions_initiated = raw_data.get("node_executions_initiated")
+        if isinstance(node_executions_initiated, dict):
+            for node, execution_ids in node_executions_initiated.items():
+                node_class = nodes.get(node)
+                if not node_class:
+                    continue
+
+                cache._node_executions_initiated[node_class].update(
+                    {UUID(execution_id) for execution_id in execution_ids}
+                )
 
-
-
-
+        node_executions_queued = raw_data.get("node_executions_queued")
+        if isinstance(node_executions_queued, dict):
+            for node, execution_ids in node_executions_queued.items():
+                node_class = nodes.get(node)
+                if not node_class:
+                    continue
 
-
-            node_class = get_class_by_qualname(node)
-            self._node_executions_initiated[node_class].update({UUID(execution_id) for execution_id in execution_ids})
+                cache._node_executions_queued[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
 
-
-            node_class = get_class_by_qualname(node)
-            self._node_executions_queued[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
+        return cache
 
     def _invoke_dependency(
         self,
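The new `deserialize` classmethod replaces the old string-keyed constructor arguments: callers pass the raw serialized cache together with a map from stringified node classes to the classes themselves, and entries for unknown nodes are skipped. A minimal sketch of the payload shape it consumes, with a hypothetical `MyNode` and placeholder UUIDs (the map keys are whatever `str(NodeClass)` yields, shown here as dotted paths for illustration):

```python
# Hypothetical example; every execution id is a string parseable by uuid.UUID.
raw_data = {
    "dependencies_invoked": {
        "00000000-0000-0000-0000-000000000001": ["examples.nodes.MyNode"],
    },
    "node_executions_fulfilled": {"examples.nodes.MyNode": ["00000000-0000-0000-0000-000000000001"]},
    "node_executions_initiated": {"examples.nodes.MyNode": []},
    "node_executions_queued": {"examples.nodes.MyNode": []},
}

# Unknown node names in raw_data are silently dropped rather than raising.
cache = NodeExecutionCache.deserialize(raw_data, nodes={"examples.nodes.MyNode": MyNode})
```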
@@ -207,8 +229,8 @@ def default_datetime_factory() -> datetime:
 
 
 class StateMeta(UniversalBaseModel):
+    workflow_definition: Type["BaseWorkflow"] = field(default_factory=import_workflow_class)
     id: UUID = field(default_factory=uuid4_default_factory)
-    trace_id: UUID = field(default_factory=uuid4_default_factory)
     span_id: UUID = field(default_factory=uuid4_default_factory)
     updated_ts: datetime = field(default_factory=default_datetime_factory)
     workflow_inputs: BaseInputs = field(default_factory=BaseInputs)
@@ -219,8 +241,6 @@ class StateMeta(UniversalBaseModel):
     __snapshot_callback__: Optional[Callable[[], None]] = field(init=False, default=None)
 
     def model_post_init(self, context: Any) -> None:
-        if self.parent:
-            self.trace_id = self.parent.meta.trace_id
         self.__snapshot_callback__ = None
 
     def add_snapshot_callback(self, callback: Callable[[], None]) -> None:
@@ -237,6 +257,25 @@ class StateMeta(UniversalBaseModel):
         if callable(self.__snapshot_callback__):
             self.__snapshot_callback__()
 
+    @field_serializer("workflow_definition")
+    def serialize_workflow_definition(self, workflow_definition: Type["BaseWorkflow"], _info: Any) -> Dict[str, Any]:
+        return serialize_type_encoder_with_id(workflow_definition)
+
+    @field_validator("workflow_definition", mode="before")
+    @classmethod
+    def deserialize_workflow_definition(cls, workflow_definition: Any, info: ValidationInfo):
+        if isinstance(workflow_definition, dict):
+            deserialized_workflow_definition = CodeResourceDefinition.model_validate(workflow_definition).decode()
+            if not is_workflow_class(deserialized_workflow_definition):
+                return import_workflow_class()
+
+            return deserialized_workflow_definition
+
+        if is_workflow_class(workflow_definition):
+            return workflow_definition
+
+        return import_workflow_class()
+
 
     @field_serializer("node_outputs")
     def serialize_node_outputs(self, node_outputs: Dict[OutputReference, Any], _info: Any) -> Dict[str, Any]:
         return {str(descriptor): value for descriptor, value in node_outputs.items()}
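Serialization and validation of `workflow_definition` are now symmetric: the serializer emits the id/name/module form produced by `serialize_type_encoder_with_id`, and the validator re-imports the class via `CodeResourceDefinition.decode()`, falling back to `BaseWorkflow` when the class cannot be resolved. A sketch of the round trip, assuming a hypothetical module-level `MyWorkflow` that carries an `__id__` attribute:

```python
meta = StateMeta(workflow_definition=MyWorkflow)
payload = meta.model_dump(mode="json")
# payload["workflow_definition"] has roughly this shape:
# {"id": "<MyWorkflow.__id__>", "name": "MyWorkflow", "module": ["examples", "workflow"]}

restored = StateMeta.model_validate(payload)
assert restored.workflow_definition is MyWorkflow  # BaseWorkflow if the re-import fails
```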
@@ -244,17 +283,16 @@ class StateMeta(UniversalBaseModel):
     @field_validator("node_outputs", mode="before")
     @classmethod
     def deserialize_node_outputs(cls, node_outputs: Any, info: ValidationInfo):
-        if isinstance(node_outputs, dict)
-
-
-
-            for node in raw_workflow_nodes:
-                Outputs = getattr(node, "Outputs", None)
-                if not isinstance(Outputs, type) or not issubclass(Outputs, BaseOutputs):
-                    continue
+        if isinstance(node_outputs, dict):
+            workflow_definition = cls._get_workflow(info)
+            if not workflow_definition:
+                return node_outputs
 
-
-
+            raw_workflow_nodes = workflow_definition.get_nodes()
+            workflow_node_outputs = {}
+            for node in raw_workflow_nodes:
+                for output in node.Outputs:
+                    workflow_node_outputs[str(output)] = output
 
             node_output_keys = list(node_outputs.keys())
             deserialized_node_outputs = {}
@@ -269,12 +307,64 @@ class StateMeta(UniversalBaseModel):
 
         return node_outputs
 
+    @field_validator("node_execution_cache", mode="before")
+    @classmethod
+    def deserialize_node_execution_cache(cls, node_execution_cache: Any, info: ValidationInfo):
+        if isinstance(node_execution_cache, dict):
+            workflow_definition = cls._get_workflow(info)
+            if not workflow_definition:
+                return node_execution_cache
+
+            nodes_cache: Dict[str, Type["BaseNode"]] = {}
+            raw_workflow_nodes = workflow_definition.get_nodes()
+            for node in raw_workflow_nodes:
+                nodes_cache[str(node)] = node
+
+            return NodeExecutionCache.deserialize(node_execution_cache, nodes_cache)
+
+        return node_execution_cache
+
+    @field_validator("workflow_inputs", mode="before")
+    @classmethod
+    def deserialize_workflow_inputs(cls, workflow_inputs: Any, info: ValidationInfo):
+        workflow_definition = cls._get_workflow(info)
+
+        if workflow_definition:
+            if workflow_inputs is None:
+                return workflow_definition.get_inputs_class()()
+            if isinstance(workflow_inputs, dict):
+                return workflow_definition.get_inputs_class()(**workflow_inputs)
+
+        return workflow_inputs
+
     @field_serializer("external_inputs")
     def serialize_external_inputs(
         self, external_inputs: Dict[ExternalInputReference, Any], _info: Any
     ) -> Dict[str, Any]:
         return {str(descriptor): value for descriptor, value in external_inputs.items()}
 
+    @field_validator("parent", mode="before")
+    @classmethod
+    def deserialize_parent(cls, parent: Any, info: ValidationInfo):
+        if isinstance(parent, dict):
+            workflow_definition = cls._get_workflow(info)
+            if not workflow_definition:
+                return parent
+
+            parent_meta = parent.get("meta")
+            if not isinstance(parent_meta, dict):
+                return parent
+
+            parent_workflow_definition = cls.deserialize_workflow_definition(
+                parent_meta.get("workflow_definition"), info
+            )
+            if not is_workflow_class(parent_workflow_definition):
+                return parent
+
+            return parent_workflow_definition.deserialize_state(parent)
+
+        return parent
+
     def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "StateMeta":
         if not memo:
             memo = {}
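`deserialize_parent` is effectively recursive: it reads the parent's own `workflow_definition` out of the serialized meta, resolves it with the validator above, and hands the whole dict to that workflow's `deserialize_state`, so nested subworkflow state round-trips with the correct classes. A sketch of the shape it expects (the names and id are hypothetical):

```python
# A child state's "parent" entry carries the parent's own meta, including the
# parent workflow's definition; deserialization then delegates to that class.
parent_blob = {
    "meta": {
        "workflow_definition": {
            "id": "00000000-0000-0000-0000-000000000002",
            "name": "ParentWorkflow",
            "module": ["examples", "flows"],
        },
    },
}
# Roughly equivalent to: ParentWorkflow.deserialize_state(parent_blob)
```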
@@ -295,6 +385,20 @@ class StateMeta(UniversalBaseModel):
 
         return super().__deepcopy__(memo)
 
+    @classmethod
+    def _get_workflow(cls, info: ValidationInfo) -> Optional[Type["BaseWorkflow"]]:
+        if not info.context:
+            return None
+
+        if not isinstance(info.context, dict):
+            return None
+
+        workflow_definition = info.context.get("workflow_definition")
+        if not is_workflow_class(workflow_definition):
+            return None
+
+        return workflow_definition
+
 
 class BaseState(metaclass=_BaseStateMeta):
     meta: StateMeta = field(init=False)
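All of the new validators funnel through `_get_workflow`, which only trusts a dict validation context carrying a real workflow class, exactly what `BaseWorkflow.deserialize_state` passes (see the `vellum/workflows/workflows/base.py` hunks below). A sketch of validating a raw meta dict directly, with `MyWorkflow` and `raw_meta` hypothetical:

```python
meta = StateMeta.model_validate(
    raw_meta,  # e.g. a dict previously produced by meta.model_dump(mode="json")
    context={"workflow_definition": MyWorkflow},  # anything else makes _get_workflow return None
)
```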
vellum/workflows/types/definition.py
ADDED
@@ -0,0 +1,71 @@
+import importlib
+import inspect
+from types import FrameType
+from uuid import UUID
+from typing import Annotated, Any, Dict, Optional, Union
+
+from pydantic import BeforeValidator
+
+from vellum.client.types.code_resource_definition import CodeResourceDefinition as ClientCodeResourceDefinition
+
+
+def serialize_type_encoder(obj: type) -> Dict[str, Any]:
+    return {
+        "name": obj.__name__,
+        "module": obj.__module__.split("."),
+    }
+
+
+def serialize_type_encoder_with_id(obj: Union[type, "CodeResourceDefinition"]) -> Dict[str, Any]:
+    if hasattr(obj, "__id__") and isinstance(obj, type):
+        return {
+            "id": getattr(obj, "__id__"),
+            **serialize_type_encoder(obj),
+        }
+    elif isinstance(obj, CodeResourceDefinition):
+        return obj.model_dump(mode="json")
+
+    raise AttributeError(f"The object of type '{type(obj).__name__}' must have an '__id__' attribute.")
+
+
+class CodeResourceDefinition(ClientCodeResourceDefinition):
+    id: UUID
+
+    @staticmethod
+    def encode(obj: type) -> "CodeResourceDefinition":
+        return CodeResourceDefinition(**serialize_type_encoder_with_id(obj))
+
+    def decode(self) -> Any:
+        if ".<locals>." in self.name:
+            # We are decoding a local class that should already be loaded in our stack frame. So
+            # we climb up to look for it.
+            frame = inspect.currentframe()
+            return self._resolve_local(frame)
+
+        try:
+            imported_module = importlib.import_module(".".join(self.module))
+        except ImportError:
+            return None
+
+        return getattr(imported_module, self.name, None)
+
+    def _resolve_local(self, frame: Optional[FrameType]) -> Any:
+        if not frame:
+            return None
+
+        frame_module = frame.f_globals.get("__name__")
+        if not isinstance(frame_module, str) or frame_module.split(".") != self.module:
+            return self._resolve_local(frame.f_back)
+
+        outer, inner = self.name.split(".<locals>.")
+        frame_outer = frame.f_code.co_name
+        if frame_outer != outer:
+            return self._resolve_local(frame.f_back)
+
+        return frame.f_locals.get(inner)
+
+
+VellumCodeResourceDefinition = Annotated[
+    CodeResourceDefinition,
+    BeforeValidator(lambda d: (d if type(d) is dict else serialize_type_encoder_with_id(d))),
+]
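`decode` is the inverse of `encode`: module-level classes are re-imported from their dotted module path, while names containing `.<locals>.` are resolved by walking up the call stack until `_resolve_local` finds the frame whose module and enclosing function match. A round-trip sketch with a hypothetical `MyWorkflow` that carries an `__id__`:

```python
definition = CodeResourceDefinition.encode(MyWorkflow)  # raises AttributeError without __id__
assert definition.decode() is MyWorkflow  # re-imported; None if the module or name is gone
```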
vellum/workflows/types/generics.py
CHANGED
@@ -1,4 +1,6 @@
-from
+from functools import cache
+from typing import TYPE_CHECKING, Any, Type, TypeVar
+from typing_extensions import TypeGuard
 
 if TYPE_CHECKING:
     from vellum.workflows import BaseWorkflow
@@ -12,3 +14,34 @@ StateType = TypeVar("StateType", bound="BaseState")
 WorkflowType = TypeVar("WorkflowType", bound="BaseWorkflow")
 InputsType = TypeVar("InputsType", bound="BaseInputs")
 OutputsType = TypeVar("OutputsType", bound="BaseOutputs")
+
+
+@cache
+def _import_node_class() -> Type["BaseNode"]:
+    """
+    Helper function to help avoid circular imports.
+    """
+
+    from vellum.workflows.nodes import BaseNode
+
+    return BaseNode
+
+
+def import_workflow_class() -> Type["BaseWorkflow"]:
+    """
+    Helper function to help avoid circular imports.
+    """
+
+    from vellum.workflows.workflows import BaseWorkflow
+
+    return BaseWorkflow
+
+
+def is_node_class(obj: Any) -> TypeGuard[Type["BaseNode"]]:
+    base_node_class = _import_node_class()
+    return isinstance(obj, type) and issubclass(obj, base_node_class)
+
+
+def is_workflow_class(obj: Any) -> TypeGuard[Type["BaseWorkflow"]]:
+    base_workflow_class = import_workflow_class()
+    return isinstance(obj, type) and issubclass(obj, base_workflow_class)
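Because both helpers return `TypeGuard[...]`, a passing check narrows the static type as well as performing the runtime check. For example (the `state_class_of` helper is hypothetical):

```python
from typing import Any

def state_class_of(obj: Any):
    if is_workflow_class(obj):
        # Statically narrowed to Type[BaseWorkflow] here, so its class methods resolve
        return obj.get_state_class()
    return None
```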
vellum/workflows/workflows/base.py
CHANGED
@@ -22,6 +22,7 @@ from typing import (
     Union,
     cast,
     get_args,
+    overload,
 )
 
 from vellum.workflows.edges import Edge
@@ -488,11 +489,13 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
                 parent=self._parent_state,
                 workflow_inputs=workflow_inputs or self.get_default_inputs(),
                 trace_id=execution_context.trace_id,
+                workflow_definition=self.__class__,
             )
             if execution_context and int(execution_context.trace_id)
             else StateMeta(
                 parent=self._parent_state,
                 workflow_inputs=workflow_inputs or self.get_default_inputs(),
+                workflow_definition=self.__class__,
             )
         )
     )
@@ -530,18 +533,30 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
 
         return most_recent_state_snapshot
 
+    @overload
+    @classmethod
+    def deserialize_state(cls, state: dict, workflow_inputs: Optional[InputsType] = None) -> StateType: ...
+
+    @overload
     @classmethod
-    def deserialize_state(cls, state:
+    def deserialize_state(cls, state: None, workflow_inputs: Optional[InputsType] = None) -> None: ...
+
+    @classmethod
+    def deserialize_state(
+        cls, state: Optional[dict], workflow_inputs: Optional[InputsType] = None
+    ) -> Optional[StateType]:
+        if state is None:
+            return None
+
         state_class = cls.get_state_class()
         if "meta" in state:
-            nodes = list(cls.get_nodes())
             state["meta"] = StateMeta.model_validate(
                 {
                     **state["meta"],
                     "workflow_inputs": workflow_inputs,
                 },
                 context={
-                    "
+                    "workflow_definition": cls,
                 },
             )
 
@@ -601,3 +616,5 @@ NodeExecutionRejectedEvent.model_rebuild()
 NodeExecutionPausedEvent.model_rebuild()
 NodeExecutionResumedEvent.model_rebuild()
 NodeExecutionStreamingEvent.model_rebuild()
+
+StateMeta.model_rebuild()
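The two `deserialize_state` overloads make the `None` passthrough visible to type checkers: a `dict` yields the workflow's state type and `None` yields `None`, with no union leaking into callers. A usage sketch (`MyWorkflow` and `payload` are hypothetical):

```python
state = MyWorkflow.deserialize_state(payload)  # payload: dict -> typed as the workflow's StateType
nothing = MyWorkflow.deserialize_state(None)   # typed as None, no cast needed
```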