vellum-ai 0.14.37__py3-none-any.whl → 0.14.39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/__init__.py +10 -0
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/reference.md +6272 -0
- vellum/client/types/__init__.py +10 -0
- vellum/client/types/ad_hoc_fulfilled_prompt_execution_meta.py +2 -0
- vellum/client/types/fulfilled_prompt_execution_meta.py +2 -0
- vellum/client/types/test_suite_run_exec_config_request.py +4 -0
- vellum/client/types/test_suite_run_progress.py +20 -0
- vellum/client/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +27 -0
- vellum/client/types/test_suite_run_prompt_sandbox_exec_config_request.py +29 -0
- vellum/client/types/test_suite_run_read.py +3 -0
- vellum/client/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +22 -0
- vellum/client/types/test_suite_run_workflow_sandbox_exec_config_request.py +29 -0
- vellum/client/types/vellum_sdk_error_code_enum.py +1 -0
- vellum/client/types/workflow_execution_event_error_code.py +1 -0
- vellum/plugins/pydantic.py +1 -1
- vellum/types/test_suite_run_progress.py +3 -0
- vellum/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +3 -0
- vellum/types/test_suite_run_prompt_sandbox_exec_config_request.py +3 -0
- vellum/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +3 -0
- vellum/types/test_suite_run_workflow_sandbox_exec_config_request.py +3 -0
- vellum/workflows/errors/types.py +1 -0
- vellum/workflows/events/node.py +2 -1
- vellum/workflows/events/tests/test_event.py +1 -0
- vellum/workflows/events/types.py +3 -40
- vellum/workflows/events/workflow.py +15 -4
- vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +7 -1
- vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +94 -3
- vellum/workflows/nodes/displayable/conftest.py +2 -6
- vellum/workflows/nodes/displayable/guardrail_node/node.py +1 -1
- vellum/workflows/nodes/displayable/guardrail_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/guardrail_node/tests/test_node.py +50 -0
- vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +6 -1
- vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +323 -0
- vellum/workflows/runner/runner.py +78 -57
- vellum/workflows/state/base.py +177 -50
- vellum/workflows/state/tests/test_state.py +26 -20
- vellum/workflows/types/definition.py +71 -0
- vellum/workflows/types/generics.py +34 -1
- vellum/workflows/workflows/base.py +26 -19
- vellum/workflows/workflows/tests/test_base_workflow.py +232 -1
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.39.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.39.dist-info}/RECORD +49 -35
- vellum_cli/push.py +2 -3
- vellum_cli/tests/test_push.py +52 -0
- vellum_ee/workflows/display/vellum.py +0 -5
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.39.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.39.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.39.dist-info}/entry_points.txt +0 -0
vellum/workflows/state/base.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
from collections import defaultdict
|
2
|
+
from contextlib import contextmanager
|
2
3
|
from copy import deepcopy
|
3
4
|
from dataclasses import field
|
4
5
|
from datetime import datetime
|
@@ -6,7 +7,7 @@ import logging
|
|
6
7
|
from queue import Queue
|
7
8
|
from threading import Lock
|
8
9
|
from uuid import UUID, uuid4
|
9
|
-
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional,
|
10
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, cast
|
10
11
|
from typing_extensions import dataclass_transform
|
11
12
|
|
12
13
|
from pydantic import GetCoreSchemaHandler, ValidationInfo, field_serializer, field_validator
|
@@ -16,19 +17,14 @@ from vellum.core.pydantic_utilities import UniversalBaseModel
|
|
16
17
|
from vellum.workflows.constants import undefined
|
17
18
|
from vellum.workflows.edges.edge import Edge
|
18
19
|
from vellum.workflows.inputs.base import BaseInputs
|
19
|
-
from vellum.workflows.outputs.base import BaseOutputs
|
20
20
|
from vellum.workflows.references import ExternalInputReference, OutputReference, StateValueReference
|
21
|
-
from vellum.workflows.types.
|
21
|
+
from vellum.workflows.types.definition import CodeResourceDefinition, serialize_type_encoder_with_id
|
22
|
+
from vellum.workflows.types.generics import StateType, import_workflow_class, is_workflow_class
|
22
23
|
from vellum.workflows.types.stack import Stack
|
23
|
-
from vellum.workflows.types.utils import
|
24
|
-
datetime_now,
|
25
|
-
deepcopy_with_exclusions,
|
26
|
-
get_class_attr_names,
|
27
|
-
get_class_by_qualname,
|
28
|
-
infer_types,
|
29
|
-
)
|
24
|
+
from vellum.workflows.types.utils import datetime_now, deepcopy_with_exclusions, get_class_attr_names, infer_types
|
30
25
|
|
31
26
|
if TYPE_CHECKING:
|
27
|
+
from vellum.workflows import BaseWorkflow
|
32
28
|
from vellum.workflows.nodes.bases import BaseNode
|
33
29
|
|
34
30
|
logger = logging.getLogger(__name__)
|
@@ -94,38 +90,65 @@ def _make_snapshottable(value: Any, snapshot_callback: Callable[[], None]) -> An
|
|
94
90
|
return value
|
95
91
|
|
96
92
|
|
93
|
+
NodeExecutionsFulfilled = Dict[Type["BaseNode"], Stack[UUID]]
|
94
|
+
NodeExecutionsInitiated = Dict[Type["BaseNode"], Set[UUID]]
|
95
|
+
NodeExecutionsQueued = Dict[Type["BaseNode"], List[UUID]]
|
96
|
+
DependenciesInvoked = Dict[UUID, Set[Type["BaseNode"]]]
|
97
|
+
|
98
|
+
|
97
99
|
class NodeExecutionCache:
|
98
|
-
_node_executions_fulfilled:
|
99
|
-
_node_executions_initiated:
|
100
|
-
_node_executions_queued:
|
101
|
-
_dependencies_invoked:
|
100
|
+
_node_executions_fulfilled: NodeExecutionsFulfilled
|
101
|
+
_node_executions_initiated: NodeExecutionsInitiated
|
102
|
+
_node_executions_queued: NodeExecutionsQueued
|
103
|
+
_dependencies_invoked: DependenciesInvoked
|
102
104
|
|
103
|
-
def __init__(
|
104
|
-
self,
|
105
|
-
dependencies_invoked: Optional[Dict[str, Sequence[str]]] = None,
|
106
|
-
node_executions_fulfilled: Optional[Dict[str, Sequence[str]]] = None,
|
107
|
-
node_executions_initiated: Optional[Dict[str, Sequence[str]]] = None,
|
108
|
-
node_executions_queued: Optional[Dict[str, Sequence[str]]] = None,
|
109
|
-
) -> None:
|
105
|
+
def __init__(self) -> None:
|
110
106
|
self._dependencies_invoked = defaultdict(set)
|
111
107
|
self._node_executions_fulfilled = defaultdict(Stack[UUID])
|
112
108
|
self._node_executions_initiated = defaultdict(set)
|
113
109
|
self._node_executions_queued = defaultdict(list)
|
114
110
|
|
115
|
-
|
116
|
-
|
111
|
+
@classmethod
|
112
|
+
def deserialize(cls, raw_data: dict, nodes: Dict[str, Type["BaseNode"]]):
|
113
|
+
cache = cls()
|
114
|
+
|
115
|
+
dependencies_invoked = raw_data.get("dependencies_invoked")
|
116
|
+
if isinstance(dependencies_invoked, dict):
|
117
|
+
for execution_id, dependencies in dependencies_invoked.items():
|
118
|
+
cache._dependencies_invoked[UUID(execution_id)] = {nodes[dep] for dep in dependencies if dep in nodes}
|
119
|
+
|
120
|
+
node_executions_fulfilled = raw_data.get("node_executions_fulfilled")
|
121
|
+
if isinstance(node_executions_fulfilled, dict):
|
122
|
+
for node, execution_ids in node_executions_fulfilled.items():
|
123
|
+
node_class = nodes.get(node)
|
124
|
+
if not node_class:
|
125
|
+
continue
|
117
126
|
|
118
|
-
|
119
|
-
|
120
|
-
|
127
|
+
cache._node_executions_fulfilled[node_class].extend(
|
128
|
+
UUID(execution_id) for execution_id in execution_ids
|
129
|
+
)
|
130
|
+
|
131
|
+
node_executions_initiated = raw_data.get("node_executions_initiated")
|
132
|
+
if isinstance(node_executions_initiated, dict):
|
133
|
+
for node, execution_ids in node_executions_initiated.items():
|
134
|
+
node_class = nodes.get(node)
|
135
|
+
if not node_class:
|
136
|
+
continue
|
137
|
+
|
138
|
+
cache._node_executions_initiated[node_class].update(
|
139
|
+
{UUID(execution_id) for execution_id in execution_ids}
|
140
|
+
)
|
141
|
+
|
142
|
+
node_executions_queued = raw_data.get("node_executions_queued")
|
143
|
+
if isinstance(node_executions_queued, dict):
|
144
|
+
for node, execution_ids in node_executions_queued.items():
|
145
|
+
node_class = nodes.get(node)
|
146
|
+
if not node_class:
|
147
|
+
continue
|
121
148
|
|
122
|
-
|
123
|
-
node_class = get_class_by_qualname(node)
|
124
|
-
self._node_executions_initiated[node_class].update({UUID(execution_id) for execution_id in execution_ids})
|
149
|
+
cache._node_executions_queued[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
|
125
150
|
|
126
|
-
|
127
|
-
node_class = get_class_by_qualname(node)
|
128
|
-
self._node_executions_queued[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
|
151
|
+
return cache
|
129
152
|
|
130
153
|
def _invoke_dependency(
|
131
154
|
self,
|
@@ -207,8 +230,8 @@ def default_datetime_factory() -> datetime:
|
|
207
230
|
|
208
231
|
|
209
232
|
class StateMeta(UniversalBaseModel):
|
233
|
+
workflow_definition: Type["BaseWorkflow"] = field(default_factory=import_workflow_class)
|
210
234
|
id: UUID = field(default_factory=uuid4_default_factory)
|
211
|
-
trace_id: UUID = field(default_factory=uuid4_default_factory)
|
212
235
|
span_id: UUID = field(default_factory=uuid4_default_factory)
|
213
236
|
updated_ts: datetime = field(default_factory=default_datetime_factory)
|
214
237
|
workflow_inputs: BaseInputs = field(default_factory=BaseInputs)
|
@@ -219,8 +242,6 @@ class StateMeta(UniversalBaseModel):
|
|
219
242
|
__snapshot_callback__: Optional[Callable[[], None]] = field(init=False, default=None)
|
220
243
|
|
221
244
|
def model_post_init(self, context: Any) -> None:
|
222
|
-
if self.parent:
|
223
|
-
self.trace_id = self.parent.meta.trace_id
|
224
245
|
self.__snapshot_callback__ = None
|
225
246
|
|
226
247
|
def add_snapshot_callback(self, callback: Callable[[], None]) -> None:
|
@@ -237,6 +258,25 @@ class StateMeta(UniversalBaseModel):
|
|
237
258
|
if callable(self.__snapshot_callback__):
|
238
259
|
self.__snapshot_callback__()
|
239
260
|
|
261
|
+
@field_serializer("workflow_definition")
|
262
|
+
def serialize_workflow_definition(self, workflow_definition: Type["BaseWorkflow"], _info: Any) -> Dict[str, Any]:
|
263
|
+
return serialize_type_encoder_with_id(workflow_definition)
|
264
|
+
|
265
|
+
@field_validator("workflow_definition", mode="before")
|
266
|
+
@classmethod
|
267
|
+
def deserialize_workflow_definition(cls, workflow_definition: Any, info: ValidationInfo):
|
268
|
+
if isinstance(workflow_definition, dict):
|
269
|
+
deserialized_workflow_definition = CodeResourceDefinition.model_validate(workflow_definition).decode()
|
270
|
+
if not is_workflow_class(deserialized_workflow_definition):
|
271
|
+
return import_workflow_class()
|
272
|
+
|
273
|
+
return deserialized_workflow_definition
|
274
|
+
|
275
|
+
if is_workflow_class(workflow_definition):
|
276
|
+
return workflow_definition
|
277
|
+
|
278
|
+
return import_workflow_class()
|
279
|
+
|
240
280
|
@field_serializer("node_outputs")
|
241
281
|
def serialize_node_outputs(self, node_outputs: Dict[OutputReference, Any], _info: Any) -> Dict[str, Any]:
|
242
282
|
return {str(descriptor): value for descriptor, value in node_outputs.items()}
|
@@ -244,17 +284,16 @@ class StateMeta(UniversalBaseModel):
|
|
244
284
|
@field_validator("node_outputs", mode="before")
|
245
285
|
@classmethod
|
246
286
|
def deserialize_node_outputs(cls, node_outputs: Any, info: ValidationInfo):
|
247
|
-
if isinstance(node_outputs, dict)
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
for node in raw_workflow_nodes:
|
252
|
-
Outputs = getattr(node, "Outputs", None)
|
253
|
-
if not isinstance(Outputs, type) or not issubclass(Outputs, BaseOutputs):
|
254
|
-
continue
|
287
|
+
if isinstance(node_outputs, dict):
|
288
|
+
workflow_definition = cls._get_workflow(info)
|
289
|
+
if not workflow_definition:
|
290
|
+
return node_outputs
|
255
291
|
|
256
|
-
|
257
|
-
|
292
|
+
raw_workflow_nodes = workflow_definition.get_nodes()
|
293
|
+
workflow_node_outputs = {}
|
294
|
+
for node in raw_workflow_nodes:
|
295
|
+
for output in node.Outputs:
|
296
|
+
workflow_node_outputs[str(output)] = output
|
258
297
|
|
259
298
|
node_output_keys = list(node_outputs.keys())
|
260
299
|
deserialized_node_outputs = {}
|
@@ -269,12 +308,64 @@ class StateMeta(UniversalBaseModel):
|
|
269
308
|
|
270
309
|
return node_outputs
|
271
310
|
|
311
|
+
@field_validator("node_execution_cache", mode="before")
|
312
|
+
@classmethod
|
313
|
+
def deserialize_node_execution_cache(cls, node_execution_cache: Any, info: ValidationInfo):
|
314
|
+
if isinstance(node_execution_cache, dict):
|
315
|
+
workflow_definition = cls._get_workflow(info)
|
316
|
+
if not workflow_definition:
|
317
|
+
return node_execution_cache
|
318
|
+
|
319
|
+
nodes_cache: Dict[str, Type["BaseNode"]] = {}
|
320
|
+
raw_workflow_nodes = workflow_definition.get_nodes()
|
321
|
+
for node in raw_workflow_nodes:
|
322
|
+
nodes_cache[str(node)] = node
|
323
|
+
|
324
|
+
return NodeExecutionCache.deserialize(node_execution_cache, nodes_cache)
|
325
|
+
|
326
|
+
return node_execution_cache
|
327
|
+
|
328
|
+
@field_validator("workflow_inputs", mode="before")
|
329
|
+
@classmethod
|
330
|
+
def deserialize_workflow_inputs(cls, workflow_inputs: Any, info: ValidationInfo):
|
331
|
+
workflow_definition = cls._get_workflow(info)
|
332
|
+
|
333
|
+
if workflow_definition:
|
334
|
+
if workflow_inputs is None:
|
335
|
+
return workflow_definition.get_inputs_class()()
|
336
|
+
if isinstance(workflow_inputs, dict):
|
337
|
+
return workflow_definition.get_inputs_class()(**workflow_inputs)
|
338
|
+
|
339
|
+
return workflow_inputs
|
340
|
+
|
272
341
|
@field_serializer("external_inputs")
|
273
342
|
def serialize_external_inputs(
|
274
343
|
self, external_inputs: Dict[ExternalInputReference, Any], _info: Any
|
275
344
|
) -> Dict[str, Any]:
|
276
345
|
return {str(descriptor): value for descriptor, value in external_inputs.items()}
|
277
346
|
|
347
|
+
@field_validator("parent", mode="before")
|
348
|
+
@classmethod
|
349
|
+
def deserialize_parent(cls, parent: Any, info: ValidationInfo):
|
350
|
+
if isinstance(parent, dict):
|
351
|
+
workflow_definition = cls._get_workflow(info)
|
352
|
+
if not workflow_definition:
|
353
|
+
return parent
|
354
|
+
|
355
|
+
parent_meta = parent.get("meta")
|
356
|
+
if not isinstance(parent_meta, dict):
|
357
|
+
return parent
|
358
|
+
|
359
|
+
parent_workflow_definition = cls.deserialize_workflow_definition(
|
360
|
+
parent_meta.get("workflow_definition"), info
|
361
|
+
)
|
362
|
+
if not is_workflow_class(parent_workflow_definition):
|
363
|
+
return parent
|
364
|
+
|
365
|
+
return parent_workflow_definition.deserialize_state(parent)
|
366
|
+
|
367
|
+
return parent
|
368
|
+
|
278
369
|
def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "StateMeta":
|
279
370
|
if not memo:
|
280
371
|
memo = {}
|
@@ -295,16 +386,30 @@ class StateMeta(UniversalBaseModel):
|
|
295
386
|
|
296
387
|
return super().__deepcopy__(memo)
|
297
388
|
|
389
|
+
@classmethod
|
390
|
+
def _get_workflow(cls, info: ValidationInfo) -> Optional[Type["BaseWorkflow"]]:
|
391
|
+
if not info.context:
|
392
|
+
return None
|
393
|
+
|
394
|
+
if not isinstance(info.context, dict):
|
395
|
+
return None
|
396
|
+
|
397
|
+
workflow_definition = info.context.get("workflow_definition")
|
398
|
+
if not is_workflow_class(workflow_definition):
|
399
|
+
return None
|
400
|
+
|
401
|
+
return workflow_definition
|
402
|
+
|
298
403
|
|
299
404
|
class BaseState(metaclass=_BaseStateMeta):
|
300
405
|
meta: StateMeta = field(init=False)
|
301
406
|
|
302
407
|
__lock__: Lock = field(init=False)
|
303
|
-
|
408
|
+
__is_quiet__: bool = field(init=False)
|
304
409
|
__snapshot_callback__: Callable[["BaseState"], None] = field(init=False)
|
305
410
|
|
306
411
|
def __init__(self, meta: Optional[StateMeta] = None, **kwargs: Any) -> None:
|
307
|
-
self.
|
412
|
+
self.__is_quiet__ = True
|
308
413
|
self.__snapshot_callback__ = lambda state: None
|
309
414
|
self.__lock__ = Lock()
|
310
415
|
|
@@ -314,14 +419,14 @@ class BaseState(metaclass=_BaseStateMeta):
|
|
314
419
|
# Make all class attribute values snapshottable
|
315
420
|
for name, value in self.__class__.__dict__.items():
|
316
421
|
if not name.startswith("_") and name != "meta":
|
317
|
-
# Bypass
|
422
|
+
# Bypass __is_quiet__ instead of `setattr`
|
318
423
|
snapshottable_value = _make_snapshottable(value, self.__snapshot__)
|
319
424
|
super().__setattr__(name, snapshottable_value)
|
320
425
|
|
321
426
|
for name, value in kwargs.items():
|
322
427
|
setattr(self, name, value)
|
323
428
|
|
324
|
-
self.
|
429
|
+
self.__is_quiet__ = False
|
325
430
|
|
326
431
|
def __deepcopy__(self, memo: Any) -> "BaseState":
|
327
432
|
new_state = deepcopy_with_exclusions(
|
@@ -368,7 +473,7 @@ class BaseState(metaclass=_BaseStateMeta):
|
|
368
473
|
return self.__dict__[key]
|
369
474
|
|
370
475
|
def __setattr__(self, name: str, value: Any) -> None:
|
371
|
-
if name.startswith("_")
|
476
|
+
if name.startswith("_"):
|
372
477
|
super().__setattr__(name, value)
|
373
478
|
return
|
374
479
|
|
@@ -409,11 +514,33 @@ class BaseState(metaclass=_BaseStateMeta):
|
|
409
514
|
Snapshots the current state to the workflow emitter. The invoked callback is overridden by the
|
410
515
|
workflow runner.
|
411
516
|
"""
|
517
|
+
if self.__is_quiet__:
|
518
|
+
return
|
519
|
+
|
412
520
|
try:
|
413
521
|
self.__snapshot_callback__(deepcopy(self))
|
414
522
|
except Exception:
|
415
523
|
logger.exception("Failed to snapshot Workflow state.")
|
416
524
|
|
525
|
+
@contextmanager
|
526
|
+
def __quiet__(self):
|
527
|
+
prev = self.__is_quiet__
|
528
|
+
self.__is_quiet__ = True
|
529
|
+
try:
|
530
|
+
yield
|
531
|
+
finally:
|
532
|
+
self.__is_quiet__ = prev
|
533
|
+
|
534
|
+
@contextmanager
|
535
|
+
def __atomic__(self):
|
536
|
+
prev = self.__is_quiet__
|
537
|
+
self.__is_quiet__ = True
|
538
|
+
try:
|
539
|
+
yield
|
540
|
+
finally:
|
541
|
+
self.__is_quiet__ = prev
|
542
|
+
self.__snapshot__()
|
543
|
+
|
417
544
|
@classmethod
|
418
545
|
def __get_pydantic_core_schema__(
|
419
546
|
cls, source_type: Type[Any], handler: GetCoreSchemaHandler
|
@@ -1,17 +1,14 @@
|
|
1
1
|
import pytest
|
2
|
-
from collections import defaultdict
|
3
2
|
from copy import deepcopy
|
4
3
|
import json
|
5
4
|
from queue import Queue
|
6
|
-
from typing import Dict
|
5
|
+
from typing import Dict, cast
|
7
6
|
|
8
7
|
from vellum.workflows.nodes.bases import BaseNode
|
9
8
|
from vellum.workflows.outputs.base import BaseOutputs
|
10
9
|
from vellum.workflows.state.base import BaseState
|
11
10
|
from vellum.workflows.state.encoder import DefaultStateEncoder
|
12
11
|
|
13
|
-
snapshot_count: Dict[int, int] = defaultdict(int)
|
14
|
-
|
15
12
|
|
16
13
|
@pytest.fixture()
|
17
14
|
def mock_deepcopy(mocker):
|
@@ -27,9 +24,19 @@ class MockState(BaseState):
|
|
27
24
|
foo: str
|
28
25
|
nested_dict: Dict[str, int] = {}
|
29
26
|
|
30
|
-
|
31
|
-
|
32
|
-
|
27
|
+
__snapshot_count__: int = 0
|
28
|
+
|
29
|
+
def __init__(self, *args, **kwargs) -> None:
|
30
|
+
super().__init__(*args, **kwargs)
|
31
|
+
self.__snapshot_callback__ = lambda _: self.__mock_snapshot__()
|
32
|
+
|
33
|
+
def __mock_snapshot__(self) -> None:
|
34
|
+
self.__snapshot_count__ += 1
|
35
|
+
|
36
|
+
def __deepcopy__(self, memo: dict) -> "MockState":
|
37
|
+
new_state = cast(MockState, super().__deepcopy__(memo))
|
38
|
+
new_state.__snapshot_count__ = 0
|
39
|
+
return new_state
|
33
40
|
|
34
41
|
|
35
42
|
class MockNode(BaseNode):
|
@@ -43,50 +50,50 @@ class MockNode(BaseNode):
|
|
43
50
|
def test_state_snapshot__node_attribute_edit():
|
44
51
|
# GIVEN an initial state instance
|
45
52
|
state = MockState(foo="bar")
|
46
|
-
assert
|
53
|
+
assert state.__snapshot_count__ == 0
|
47
54
|
|
48
55
|
# WHEN we edit an attribute
|
49
56
|
state.foo = "baz"
|
50
57
|
|
51
58
|
# THEN the snapshot is emitted
|
52
|
-
assert
|
59
|
+
assert state.__snapshot_count__ == 1
|
53
60
|
|
54
61
|
|
55
62
|
def test_state_snapshot__node_output_edit():
|
56
63
|
# GIVEN an initial state instance
|
57
64
|
state = MockState(foo="bar")
|
58
|
-
assert
|
65
|
+
assert state.__snapshot_count__ == 0
|
59
66
|
|
60
67
|
# WHEN we add a Node Output to state
|
61
68
|
for output in MockNode.Outputs:
|
62
69
|
state.meta.node_outputs[output] = "hello"
|
63
70
|
|
64
71
|
# THEN the snapshot is emitted
|
65
|
-
assert
|
72
|
+
assert state.__snapshot_count__ == 1
|
66
73
|
|
67
74
|
|
68
75
|
def test_state_snapshot__nested_dictionary_edit():
|
69
76
|
# GIVEN an initial state instance
|
70
77
|
state = MockState(foo="bar")
|
71
|
-
assert
|
78
|
+
assert state.__snapshot_count__ == 0
|
72
79
|
|
73
80
|
# WHEN we edit a nested dictionary
|
74
81
|
state.nested_dict["hello"] = 1
|
75
82
|
|
76
83
|
# THEN the snapshot is emitted
|
77
|
-
assert
|
84
|
+
assert state.__snapshot_count__ == 1
|
78
85
|
|
79
86
|
|
80
87
|
def test_state_snapshot__external_input_edit():
|
81
88
|
# GIVEN an initial state instance
|
82
89
|
state = MockState(foo="bar")
|
83
|
-
assert
|
90
|
+
assert state.__snapshot_count__ == 0
|
84
91
|
|
85
92
|
# WHEN we add an external input to state
|
86
93
|
state.meta.external_inputs[MockNode.ExternalInputs.message] = "hello"
|
87
94
|
|
88
95
|
# THEN the snapshot is emitted
|
89
|
-
assert
|
96
|
+
assert state.__snapshot_count__ == 1
|
90
97
|
|
91
98
|
|
92
99
|
def test_state_deepcopy():
|
@@ -103,7 +110,6 @@ def test_state_deepcopy():
|
|
103
110
|
assert deepcopied_state.meta.node_outputs == state.meta.node_outputs
|
104
111
|
|
105
112
|
|
106
|
-
@pytest.mark.skip(reason="https://app.shortcut.com/vellum/story/5654")
|
107
113
|
def test_state_deepcopy__with_node_output_updates():
|
108
114
|
# GIVEN an initial state instance
|
109
115
|
state = MockState(foo="bar")
|
@@ -121,10 +127,10 @@ def test_state_deepcopy__with_node_output_updates():
|
|
121
127
|
assert deepcopied_state.meta.node_outputs[MockNode.Outputs.baz] == "hello"
|
122
128
|
|
123
129
|
# AND the original state has had the correct number of snapshots
|
124
|
-
assert
|
130
|
+
assert state.__snapshot_count__ == 2
|
125
131
|
|
126
132
|
# AND the copied state has had the correct number of snapshots
|
127
|
-
assert
|
133
|
+
assert deepcopied_state.__snapshot_count__ == 0
|
128
134
|
|
129
135
|
|
130
136
|
def test_state_json_serialization__with_node_output_updates():
|
@@ -158,10 +164,10 @@ def test_state_deepcopy__with_external_input_updates():
|
|
158
164
|
assert deepcopied_state.meta.external_inputs[MockNode.ExternalInputs.message] == "hello"
|
159
165
|
|
160
166
|
# AND the original state has had the correct number of snapshots
|
161
|
-
assert
|
167
|
+
assert state.__snapshot_count__ == 2
|
162
168
|
|
163
169
|
# AND the copied state has had the correct number of snapshots
|
164
|
-
assert
|
170
|
+
assert deepcopied_state.__snapshot_count__ == 0
|
165
171
|
|
166
172
|
|
167
173
|
def test_state_json_serialization__with_queue():
|
@@ -0,0 +1,71 @@
|
|
1
|
+
import importlib
|
2
|
+
import inspect
|
3
|
+
from types import FrameType
|
4
|
+
from uuid import UUID
|
5
|
+
from typing import Annotated, Any, Dict, Optional, Union
|
6
|
+
|
7
|
+
from pydantic import BeforeValidator
|
8
|
+
|
9
|
+
from vellum.client.types.code_resource_definition import CodeResourceDefinition as ClientCodeResourceDefinition
|
10
|
+
|
11
|
+
|
12
|
+
def serialize_type_encoder(obj: type) -> Dict[str, Any]:
|
13
|
+
return {
|
14
|
+
"name": obj.__name__,
|
15
|
+
"module": obj.__module__.split("."),
|
16
|
+
}
|
17
|
+
|
18
|
+
|
19
|
+
def serialize_type_encoder_with_id(obj: Union[type, "CodeResourceDefinition"]) -> Dict[str, Any]:
|
20
|
+
if hasattr(obj, "__id__") and isinstance(obj, type):
|
21
|
+
return {
|
22
|
+
"id": getattr(obj, "__id__"),
|
23
|
+
**serialize_type_encoder(obj),
|
24
|
+
}
|
25
|
+
elif isinstance(obj, CodeResourceDefinition):
|
26
|
+
return obj.model_dump(mode="json")
|
27
|
+
|
28
|
+
raise AttributeError(f"The object of type '{type(obj).__name__}' must have an '__id__' attribute.")
|
29
|
+
|
30
|
+
|
31
|
+
class CodeResourceDefinition(ClientCodeResourceDefinition):
|
32
|
+
id: UUID
|
33
|
+
|
34
|
+
@staticmethod
|
35
|
+
def encode(obj: type) -> "CodeResourceDefinition":
|
36
|
+
return CodeResourceDefinition(**serialize_type_encoder_with_id(obj))
|
37
|
+
|
38
|
+
def decode(self) -> Any:
|
39
|
+
if ".<locals>." in self.name:
|
40
|
+
# We are decoding a local class that should already be loaded in our stack frame. So
|
41
|
+
# we climb up to look for it.
|
42
|
+
frame = inspect.currentframe()
|
43
|
+
return self._resolve_local(frame)
|
44
|
+
|
45
|
+
try:
|
46
|
+
imported_module = importlib.import_module(".".join(self.module))
|
47
|
+
except ImportError:
|
48
|
+
return None
|
49
|
+
|
50
|
+
return getattr(imported_module, self.name, None)
|
51
|
+
|
52
|
+
def _resolve_local(self, frame: Optional[FrameType]) -> Any:
|
53
|
+
if not frame:
|
54
|
+
return None
|
55
|
+
|
56
|
+
frame_module = frame.f_globals.get("__name__")
|
57
|
+
if not isinstance(frame_module, str) or frame_module.split(".") != self.module:
|
58
|
+
return self._resolve_local(frame.f_back)
|
59
|
+
|
60
|
+
outer, inner = self.name.split(".<locals>.")
|
61
|
+
frame_outer = frame.f_code.co_name
|
62
|
+
if frame_outer != outer:
|
63
|
+
return self._resolve_local(frame.f_back)
|
64
|
+
|
65
|
+
return frame.f_locals.get(inner)
|
66
|
+
|
67
|
+
|
68
|
+
VellumCodeResourceDefinition = Annotated[
|
69
|
+
CodeResourceDefinition,
|
70
|
+
BeforeValidator(lambda d: (d if type(d) is dict else serialize_type_encoder_with_id(d))),
|
71
|
+
]
|
@@ -1,4 +1,6 @@
|
|
1
|
-
from
|
1
|
+
from functools import cache
|
2
|
+
from typing import TYPE_CHECKING, Any, Type, TypeVar
|
3
|
+
from typing_extensions import TypeGuard
|
2
4
|
|
3
5
|
if TYPE_CHECKING:
|
4
6
|
from vellum.workflows import BaseWorkflow
|
@@ -12,3 +14,34 @@ StateType = TypeVar("StateType", bound="BaseState")
|
|
12
14
|
WorkflowType = TypeVar("WorkflowType", bound="BaseWorkflow")
|
13
15
|
InputsType = TypeVar("InputsType", bound="BaseInputs")
|
14
16
|
OutputsType = TypeVar("OutputsType", bound="BaseOutputs")
|
17
|
+
|
18
|
+
|
19
|
+
@cache
|
20
|
+
def _import_node_class() -> Type["BaseNode"]:
|
21
|
+
"""
|
22
|
+
Helper function to help avoid circular imports.
|
23
|
+
"""
|
24
|
+
|
25
|
+
from vellum.workflows.nodes import BaseNode
|
26
|
+
|
27
|
+
return BaseNode
|
28
|
+
|
29
|
+
|
30
|
+
def import_workflow_class() -> Type["BaseWorkflow"]:
|
31
|
+
"""
|
32
|
+
Helper function to help avoid circular imports.
|
33
|
+
"""
|
34
|
+
|
35
|
+
from vellum.workflows.workflows import BaseWorkflow
|
36
|
+
|
37
|
+
return BaseWorkflow
|
38
|
+
|
39
|
+
|
40
|
+
def is_node_class(obj: Any) -> TypeGuard[Type["BaseNode"]]:
|
41
|
+
base_node_class = _import_node_class()
|
42
|
+
return isinstance(obj, type) and issubclass(obj, base_node_class)
|
43
|
+
|
44
|
+
|
45
|
+
def is_workflow_class(obj: Any) -> TypeGuard[Type["BaseWorkflow"]]:
|
46
|
+
base_workflow_class = import_workflow_class()
|
47
|
+
return isinstance(obj, type) and issubclass(obj, base_workflow_class)
|