vellum-ai 0.14.36__py3-none-any.whl → 0.14.38__py3-none-any.whl
This diff compares the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
- vellum/__init__.py +8 -0
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/reference.md +6272 -0
- vellum/client/types/__init__.py +8 -0
- vellum/client/types/ad_hoc_fulfilled_prompt_execution_meta.py +2 -0
- vellum/client/types/fulfilled_prompt_execution_meta.py +2 -0
- vellum/client/types/logical_operator.py +1 -0
- vellum/client/types/test_suite_run_exec_config_request.py +4 -0
- vellum/client/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +27 -0
- vellum/client/types/test_suite_run_prompt_sandbox_exec_config_request.py +29 -0
- vellum/client/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +22 -0
- vellum/client/types/test_suite_run_workflow_sandbox_exec_config_request.py +29 -0
- vellum/plugins/pydantic.py +1 -1
- vellum/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +3 -0
- vellum/types/test_suite_run_prompt_sandbox_exec_config_request.py +3 -0
- vellum/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +3 -0
- vellum/types/test_suite_run_workflow_sandbox_exec_config_request.py +3 -0
- vellum/workflows/events/node.py +2 -1
- vellum/workflows/events/types.py +3 -40
- vellum/workflows/events/workflow.py +2 -1
- vellum/workflows/inputs/base.py +2 -1
- vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +96 -3
- vellum/workflows/nodes/displayable/conftest.py +2 -6
- vellum/workflows/nodes/displayable/guardrail_node/node.py +15 -7
- vellum/workflows/nodes/displayable/guardrail_node/test_node.py +25 -0
- vellum/workflows/nodes/displayable/guardrail_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/guardrail_node/tests/test_node.py +50 -0
- vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +297 -0
- vellum/workflows/runner/runner.py +44 -43
- vellum/workflows/state/base.py +169 -36
- vellum/workflows/types/definition.py +71 -0
- vellum/workflows/types/generics.py +34 -1
- vellum/workflows/workflows/base.py +34 -0
- vellum/workflows/workflows/tests/test_base_workflow.py +270 -0
- {vellum_ai-0.14.36.dist-info → vellum_ai-0.14.38.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.36.dist-info → vellum_ai-0.14.38.dist-info}/RECORD +52 -39
- vellum_ee/workflows/display/base.py +9 -7
- vellum_ee/workflows/display/nodes/__init__.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/note_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +2 -0
- vellum_ee/workflows/display/nodes/vellum/tests/test_note_node.py +33 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +3 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +1 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +0 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -0
- vellum_ee/workflows/display/types.py +6 -7
- vellum_ee/workflows/display/vellum.py +5 -9
- vellum_ee/workflows/display/workflows/base_workflow_display.py +20 -19
- vellum_ee/workflows/display/workflows/vellum_workflow_display.py +11 -37
- {vellum_ai-0.14.36.dist-info → vellum_ai-0.14.38.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.36.dist-info → vellum_ai-0.14.38.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.36.dist-info → vellum_ai-0.14.38.dist-info}/entry_points.txt +0 -0
vellum/workflows/state/base.py
CHANGED
```diff
@@ -6,10 +6,10 @@ import logging
 from queue import Queue
 from threading import Lock
 from uuid import UUID, uuid4
-from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional,
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, cast
 from typing_extensions import dataclass_transform
 
-from pydantic import GetCoreSchemaHandler, field_serializer
+from pydantic import GetCoreSchemaHandler, ValidationInfo, field_serializer, field_validator
 from pydantic_core import core_schema
 
 from vellum.core.pydantic_utilities import UniversalBaseModel
@@ -17,17 +17,13 @@ from vellum.workflows.constants import undefined
 from vellum.workflows.edges.edge import Edge
 from vellum.workflows.inputs.base import BaseInputs
 from vellum.workflows.references import ExternalInputReference, OutputReference, StateValueReference
-from vellum.workflows.types.
+from vellum.workflows.types.definition import CodeResourceDefinition, serialize_type_encoder_with_id
+from vellum.workflows.types.generics import StateType, import_workflow_class, is_workflow_class
 from vellum.workflows.types.stack import Stack
-from vellum.workflows.types.utils import (
-    datetime_now,
-    deepcopy_with_exclusions,
-    get_class_attr_names,
-    get_class_by_qualname,
-    infer_types,
-)
+from vellum.workflows.types.utils import datetime_now, deepcopy_with_exclusions, get_class_attr_names, infer_types
 
 if TYPE_CHECKING:
+    from vellum.workflows import BaseWorkflow
     from vellum.workflows.nodes.bases import BaseNode
 
 logger = logging.getLogger(__name__)
@@ -50,7 +46,7 @@ class _Snapshottable:
 class _BaseStateMeta(type):
     def __getattribute__(cls, name: str) -> Any:
         if not name.startswith("_"):
-            instance = vars(cls).get(name)
+            instance = vars(cls).get(name, undefined)
             types = infer_types(cls, name)
             return StateValueReference(name=name, types=types, instance=instance)
 
```
```diff
@@ -93,38 +89,65 @@ def _make_snapshottable(value: Any, snapshot_callback: Callable[[], None]) -> Any:
     return value
 
 
+NodeExecutionsFulfilled = Dict[Type["BaseNode"], Stack[UUID]]
+NodeExecutionsInitiated = Dict[Type["BaseNode"], Set[UUID]]
+NodeExecutionsQueued = Dict[Type["BaseNode"], List[UUID]]
+DependenciesInvoked = Dict[UUID, Set[Type["BaseNode"]]]
+
+
 class NodeExecutionCache:
-    _node_executions_fulfilled:
-    _node_executions_initiated:
-    _node_executions_queued:
-    _dependencies_invoked:
+    _node_executions_fulfilled: NodeExecutionsFulfilled
+    _node_executions_initiated: NodeExecutionsInitiated
+    _node_executions_queued: NodeExecutionsQueued
+    _dependencies_invoked: DependenciesInvoked
 
-    def __init__(
-        self,
-        dependencies_invoked: Optional[Dict[str, Sequence[str]]] = None,
-        node_executions_fulfilled: Optional[Dict[str, Sequence[str]]] = None,
-        node_executions_initiated: Optional[Dict[str, Sequence[str]]] = None,
-        node_executions_queued: Optional[Dict[str, Sequence[str]]] = None,
-    ) -> None:
+    def __init__(self) -> None:
        self._dependencies_invoked = defaultdict(set)
        self._node_executions_fulfilled = defaultdict(Stack[UUID])
        self._node_executions_initiated = defaultdict(set)
        self._node_executions_queued = defaultdict(list)
 
-
-
+    @classmethod
+    def deserialize(cls, raw_data: dict, nodes: Dict[str, Type["BaseNode"]]):
+        cache = cls()
+
+        dependencies_invoked = raw_data.get("dependencies_invoked")
+        if isinstance(dependencies_invoked, dict):
+            for execution_id, dependencies in dependencies_invoked.items():
+                cache._dependencies_invoked[UUID(execution_id)] = {nodes[dep] for dep in dependencies if dep in nodes}
+
+        node_executions_fulfilled = raw_data.get("node_executions_fulfilled")
+        if isinstance(node_executions_fulfilled, dict):
+            for node, execution_ids in node_executions_fulfilled.items():
+                node_class = nodes.get(node)
+                if not node_class:
+                    continue
 
-
-
-
+                cache._node_executions_fulfilled[node_class].extend(
+                    UUID(execution_id) for execution_id in execution_ids
+                )
+
+        node_executions_initiated = raw_data.get("node_executions_initiated")
+        if isinstance(node_executions_initiated, dict):
+            for node, execution_ids in node_executions_initiated.items():
+                node_class = nodes.get(node)
+                if not node_class:
+                    continue
+
+                cache._node_executions_initiated[node_class].update(
+                    {UUID(execution_id) for execution_id in execution_ids}
+                )
+
+        node_executions_queued = raw_data.get("node_executions_queued")
+        if isinstance(node_executions_queued, dict):
+            for node, execution_ids in node_executions_queued.items():
+                node_class = nodes.get(node)
+                if not node_class:
+                    continue
 
-
-            node_class = get_class_by_qualname(node)
-            self._node_executions_initiated[node_class].update({UUID(execution_id) for execution_id in execution_ids})
+                cache._node_executions_queued[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
 
-
-            node_class = get_class_by_qualname(node)
-            self._node_executions_queued[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
+        return cache
 
     def _invoke_dependency(
         self,
```
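The new `deserialize` classmethod replaces the old string-keyed constructor arguments: the caller now resolves serialized node identifiers to node classes up front and passes that mapping in. A minimal sketch of how it might be driven, assuming vellum-ai 0.14.38 is installed (`MyNode` is a hypothetical node class):

```python
# Sketch: rebuilding a NodeExecutionCache from persisted, JSON-shaped data.
# MyNode is hypothetical; the keys of `nodes` must match the serialized node
# identifiers (the StateMeta validator builds them via str(node_class)).
from uuid import uuid4

from vellum.workflows.nodes import BaseNode
from vellum.workflows.state.base import NodeExecutionCache


class MyNode(BaseNode):
    pass


raw_data = {
    "dependencies_invoked": {},
    "node_executions_initiated": {},
    "node_executions_queued": {},
    "node_executions_fulfilled": {str(MyNode): [str(uuid4())]},
}

cache = NodeExecutionCache.deserialize(raw_data, nodes={str(MyNode): MyNode})
```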
```diff
@@ -206,8 +229,8 @@ def default_datetime_factory() -> datetime:
 
 
 class StateMeta(UniversalBaseModel):
+    workflow_definition: Type["BaseWorkflow"] = field(default_factory=import_workflow_class)
     id: UUID = field(default_factory=uuid4_default_factory)
-    trace_id: UUID = field(default_factory=uuid4_default_factory)
     span_id: UUID = field(default_factory=uuid4_default_factory)
     updated_ts: datetime = field(default_factory=default_datetime_factory)
     workflow_inputs: BaseInputs = field(default_factory=BaseInputs)
@@ -218,8 +241,6 @@ class StateMeta(UniversalBaseModel):
     __snapshot_callback__: Optional[Callable[[], None]] = field(init=False, default=None)
 
     def model_post_init(self, context: Any) -> None:
-        if self.parent:
-            self.trace_id = self.parent.meta.trace_id
         self.__snapshot_callback__ = None
 
     def add_snapshot_callback(self, callback: Callable[[], None]) -> None:
@@ -236,16 +257,114 @@ class StateMeta(UniversalBaseModel):
         if callable(self.__snapshot_callback__):
             self.__snapshot_callback__()
 
+    @field_serializer("workflow_definition")
+    def serialize_workflow_definition(self, workflow_definition: Type["BaseWorkflow"], _info: Any) -> Dict[str, Any]:
+        return serialize_type_encoder_with_id(workflow_definition)
+
+    @field_validator("workflow_definition", mode="before")
+    @classmethod
+    def deserialize_workflow_definition(cls, workflow_definition: Any, info: ValidationInfo):
+        if isinstance(workflow_definition, dict):
+            deserialized_workflow_definition = CodeResourceDefinition.model_validate(workflow_definition).decode()
+            if not is_workflow_class(deserialized_workflow_definition):
+                return import_workflow_class()
+
+            return deserialized_workflow_definition
+
+        if is_workflow_class(workflow_definition):
+            return workflow_definition
+
+        return import_workflow_class()
+
     @field_serializer("node_outputs")
     def serialize_node_outputs(self, node_outputs: Dict[OutputReference, Any], _info: Any) -> Dict[str, Any]:
         return {str(descriptor): value for descriptor, value in node_outputs.items()}
 
+    @field_validator("node_outputs", mode="before")
+    @classmethod
+    def deserialize_node_outputs(cls, node_outputs: Any, info: ValidationInfo):
+        if isinstance(node_outputs, dict):
+            workflow_definition = cls._get_workflow(info)
+            if not workflow_definition:
+                return node_outputs
+
+            raw_workflow_nodes = workflow_definition.get_nodes()
+            workflow_node_outputs = {}
+            for node in raw_workflow_nodes:
+                for output in node.Outputs:
+                    workflow_node_outputs[str(output)] = output
+
+            node_output_keys = list(node_outputs.keys())
+            deserialized_node_outputs = {}
+            for node_output_key in node_output_keys:
+                output_reference = workflow_node_outputs.get(node_output_key)
+                if not output_reference:
+                    continue
+
+                deserialized_node_outputs[output_reference] = node_outputs[node_output_key]
+
+            return deserialized_node_outputs
+
+        return node_outputs
+
+    @field_validator("node_execution_cache", mode="before")
+    @classmethod
+    def deserialize_node_execution_cache(cls, node_execution_cache: Any, info: ValidationInfo):
+        if isinstance(node_execution_cache, dict):
+            workflow_definition = cls._get_workflow(info)
+            if not workflow_definition:
+                return node_execution_cache
+
+            nodes_cache: Dict[str, Type["BaseNode"]] = {}
+            raw_workflow_nodes = workflow_definition.get_nodes()
+            for node in raw_workflow_nodes:
+                nodes_cache[str(node)] = node
+
+            return NodeExecutionCache.deserialize(node_execution_cache, nodes_cache)
+
+        return node_execution_cache
+
+    @field_validator("workflow_inputs", mode="before")
+    @classmethod
+    def deserialize_workflow_inputs(cls, workflow_inputs: Any, info: ValidationInfo):
+        workflow_definition = cls._get_workflow(info)
+
+        if workflow_definition:
+            if workflow_inputs is None:
+                return workflow_definition.get_inputs_class()()
+            if isinstance(workflow_inputs, dict):
+                return workflow_definition.get_inputs_class()(**workflow_inputs)
+
+        return workflow_inputs
+
     @field_serializer("external_inputs")
     def serialize_external_inputs(
         self, external_inputs: Dict[ExternalInputReference, Any], _info: Any
     ) -> Dict[str, Any]:
         return {str(descriptor): value for descriptor, value in external_inputs.items()}
 
+    @field_validator("parent", mode="before")
+    @classmethod
+    def deserialize_parent(cls, parent: Any, info: ValidationInfo):
+        if isinstance(parent, dict):
+            workflow_definition = cls._get_workflow(info)
+            if not workflow_definition:
+                return parent
+
+            parent_meta = parent.get("meta")
+            if not isinstance(parent_meta, dict):
+                return parent
+
+            parent_workflow_definition = cls.deserialize_workflow_definition(
+                parent_meta.get("workflow_definition"), info
+            )
+            if not is_workflow_class(parent_workflow_definition):
+                return parent
+
+            return parent_workflow_definition.deserialize_state(parent)
+
+        return parent
+
     def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "StateMeta":
         if not memo:
             memo = {}
```
```diff
@@ -266,6 +385,20 @@ class StateMeta(UniversalBaseModel):
 
         return super().__deepcopy__(memo)
 
+    @classmethod
+    def _get_workflow(cls, info: ValidationInfo) -> Optional[Type["BaseWorkflow"]]:
+        if not info.context:
+            return None
+
+        if not isinstance(info.context, dict):
+            return None
+
+        workflow_definition = info.context.get("workflow_definition")
+        if not is_workflow_class(workflow_definition):
+            return None
+
+        return workflow_definition
+
 
 class BaseState(metaclass=_BaseStateMeta):
     meta: StateMeta = field(init=False)
```
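Taken together, these serializers and `mode="before"` validators make `StateMeta` round-trippable: references are flattened to strings on the way out and resolved back against the workflow class that `_get_workflow` pulls from pydantic's validation context. A rough sketch of the mechanism (`MyWorkflow` stands in for any concrete `BaseWorkflow` subclass):

```python
# Sketch: the validators above only engage when a workflow class is supplied
# via the validation context; otherwise raw dicts pass through unchanged.
# MyWorkflow is a hypothetical BaseWorkflow subclass defined elsewhere.
from vellum.workflows.state.base import StateMeta

raw_meta = {
    "node_outputs": {},          # serialized keys are str(OutputReference)
    "node_execution_cache": {},  # consumed by NodeExecutionCache.deserialize
}

meta = StateMeta.model_validate(raw_meta, context={"workflow_definition": MyWorkflow})
```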
vellum/workflows/types/definition.py
ADDED
```diff
@@ -0,0 +1,71 @@
+import importlib
+import inspect
+from types import FrameType
+from uuid import UUID
+from typing import Annotated, Any, Dict, Optional, Union
+
+from pydantic import BeforeValidator
+
+from vellum.client.types.code_resource_definition import CodeResourceDefinition as ClientCodeResourceDefinition
+
+
+def serialize_type_encoder(obj: type) -> Dict[str, Any]:
+    return {
+        "name": obj.__name__,
+        "module": obj.__module__.split("."),
+    }
+
+
+def serialize_type_encoder_with_id(obj: Union[type, "CodeResourceDefinition"]) -> Dict[str, Any]:
+    if hasattr(obj, "__id__") and isinstance(obj, type):
+        return {
+            "id": getattr(obj, "__id__"),
+            **serialize_type_encoder(obj),
+        }
+    elif isinstance(obj, CodeResourceDefinition):
+        return obj.model_dump(mode="json")
+
+    raise AttributeError(f"The object of type '{type(obj).__name__}' must have an '__id__' attribute.")
+
+
+class CodeResourceDefinition(ClientCodeResourceDefinition):
+    id: UUID
+
+    @staticmethod
+    def encode(obj: type) -> "CodeResourceDefinition":
+        return CodeResourceDefinition(**serialize_type_encoder_with_id(obj))
+
+    def decode(self) -> Any:
+        if ".<locals>." in self.name:
+            # We are decoding a local class that should already be loaded in our stack frame. So
+            # we climb up to look for it.
+            frame = inspect.currentframe()
+            return self._resolve_local(frame)
+
+        try:
+            imported_module = importlib.import_module(".".join(self.module))
+        except ImportError:
+            return None
+
+        return getattr(imported_module, self.name, None)
+
+    def _resolve_local(self, frame: Optional[FrameType]) -> Any:
+        if not frame:
+            return None
+
+        frame_module = frame.f_globals.get("__name__")
+        if not isinstance(frame_module, str) or frame_module.split(".") != self.module:
+            return self._resolve_local(frame.f_back)
+
+        outer, inner = self.name.split(".<locals>.")
+        frame_outer = frame.f_code.co_name
+        if frame_outer != outer:
+            return self._resolve_local(frame.f_back)
+
+        return frame.f_locals.get(inner)
+
+
+VellumCodeResourceDefinition = Annotated[
+    CodeResourceDefinition,
+    BeforeValidator(lambda d: (d if type(d) is dict else serialize_type_encoder_with_id(d))),
+]
```
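`encode` and `decode` give classes a serializable identity: `encode` captures a class's `__id__`, name, and module path, while `decode` re-imports it, walking back up the call stack for `<locals>` classes. A small round-trip sketch; attaching `__id__` by hand is only for illustration, since `serialize_type_encoder_with_id` raises when it is missing:

```python
# Sketch: round-tripping a module-level class through CodeResourceDefinition.
# encode() requires an __id__ attribute, so we attach one here for illustration.
from uuid import uuid4

from vellum.workflows.types.definition import CodeResourceDefinition


class Example:
    __id__ = uuid4()


definition = CodeResourceDefinition.encode(Example)
assert definition.name == "Example"
assert definition.decode() is Example  # resolved via importlib + getattr
```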
vellum/workflows/types/generics.py
CHANGED
```diff
@@ -1,4 +1,6 @@
-from
+from functools import cache
+from typing import TYPE_CHECKING, Any, Type, TypeVar
+from typing_extensions import TypeGuard
 
 if TYPE_CHECKING:
     from vellum.workflows import BaseWorkflow
@@ -12,3 +14,34 @@ StateType = TypeVar("StateType", bound="BaseState")
 WorkflowType = TypeVar("WorkflowType", bound="BaseWorkflow")
 InputsType = TypeVar("InputsType", bound="BaseInputs")
 OutputsType = TypeVar("OutputsType", bound="BaseOutputs")
+
+
+@cache
+def _import_node_class() -> Type["BaseNode"]:
+    """
+    Helper function to help avoid circular imports.
+    """
+
+    from vellum.workflows.nodes import BaseNode
+
+    return BaseNode
+
+
+def import_workflow_class() -> Type["BaseWorkflow"]:
+    """
+    Helper function to help avoid circular imports.
+    """
+
+    from vellum.workflows.workflows import BaseWorkflow
+
+    return BaseWorkflow
+
+
+def is_node_class(obj: Any) -> TypeGuard[Type["BaseNode"]]:
+    base_node_class = _import_node_class()
+    return isinstance(obj, type) and issubclass(obj, base_node_class)
+
+
+def is_workflow_class(obj: Any) -> TypeGuard[Type["BaseWorkflow"]]:
+    base_workflow_class = import_workflow_class()
+    return isinstance(obj, type) and issubclass(obj, base_workflow_class)
```
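The `TypeGuard` helpers let callers narrow arbitrary objects to workflow or node classes without importing `BaseWorkflow` or `BaseNode` at module import time, which is what breaks the circular-import cycle. A brief usage sketch:

```python
# Sketch: narrowing with the new TypeGuard helpers. Inside each branch, a type
# checker treats obj as Type[BaseWorkflow] or Type[BaseNode] respectively.
from vellum.workflows.types.generics import is_node_class, is_workflow_class


def describe(obj: object) -> str:
    if is_workflow_class(obj):
        return f"workflow class: {obj.__name__}"
    if is_node_class(obj):
        return f"node class: {obj.__name__}"
    return "neither"
```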
vellum/workflows/workflows/base.py
CHANGED
```diff
@@ -22,6 +22,7 @@ from typing import (
     Union,
     cast,
     get_args,
+    overload,
 )
 
 from vellum.workflows.edges import Edge
@@ -488,11 +489,13 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
                 parent=self._parent_state,
                 workflow_inputs=workflow_inputs or self.get_default_inputs(),
                 trace_id=execution_context.trace_id,
+                workflow_definition=self.__class__,
             )
             if execution_context and int(execution_context.trace_id)
             else StateMeta(
                 parent=self._parent_state,
                 workflow_inputs=workflow_inputs or self.get_default_inputs(),
+                workflow_definition=self.__class__,
             )
         )
     )
@@ -530,6 +533,35 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
 
         return most_recent_state_snapshot
 
+    @overload
+    @classmethod
+    def deserialize_state(cls, state: dict, workflow_inputs: Optional[InputsType] = None) -> StateType: ...
+
+    @overload
+    @classmethod
+    def deserialize_state(cls, state: None, workflow_inputs: Optional[InputsType] = None) -> None: ...
+
+    @classmethod
+    def deserialize_state(
+        cls, state: Optional[dict], workflow_inputs: Optional[InputsType] = None
+    ) -> Optional[StateType]:
+        if state is None:
+            return None
+
+        state_class = cls.get_state_class()
+        if "meta" in state:
+            state["meta"] = StateMeta.model_validate(
+                {
+                    **state["meta"],
+                    "workflow_inputs": workflow_inputs,
+                },
+                context={
+                    "workflow_definition": cls,
+                },
+            )
+
+        return state_class(**state)
+
     @staticmethod
     def load_from_module(module_path: str) -> Type["BaseWorkflow"]:
         workflow_path = f"{module_path}.workflow"
@@ -584,3 +616,5 @@ NodeExecutionRejectedEvent.model_rebuild()
 NodeExecutionPausedEvent.model_rebuild()
 NodeExecutionResumedEvent.model_rebuild()
 NodeExecutionStreamingEvent.model_rebuild()
+
+StateMeta.model_rebuild()
```
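`deserialize_state` is the public entry point for the machinery above: it validates the `meta` payload with the owning workflow class in the pydantic context, then rehydrates the workflow's state class. A usage sketch, with `MyWorkflow`, `MyInputs`, and `saved_payload` as hypothetical stand-ins:

```python
# Sketch: restoring a previously serialized state dict. MyWorkflow, MyInputs,
# and saved_payload are hypothetical placeholders for a real workflow project.
import json

raw_state = json.loads(saved_payload)  # dict captured from an earlier run
state = MyWorkflow.deserialize_state(raw_state, workflow_inputs=MyInputs(foo="bar"))
assert state is not None  # per the overloads, a dict input yields StateType
```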