vellum-ai 0.14.38__py3-none-any.whl → 0.14.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/__init__.py +2 -0
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/types/__init__.py +2 -0
- vellum/client/types/test_suite_run_progress.py +20 -0
- vellum/client/types/test_suite_run_read.py +3 -0
- vellum/client/types/vellum_sdk_error_code_enum.py +1 -0
- vellum/client/types/workflow_execution_event_error_code.py +1 -0
- vellum/types/test_suite_run_progress.py +3 -0
- vellum/workflows/errors/types.py +1 -0
- vellum/workflows/events/tests/test_event.py +1 -0
- vellum/workflows/events/workflow.py +13 -3
- vellum/workflows/nodes/bases/tests/test_base_node.py +24 -0
- vellum/workflows/nodes/core/try_node/node.py +1 -2
- vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +7 -1
- vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +6 -1
- vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +26 -0
- vellum/workflows/nodes/experimental/tool_calling_node/__init__.py +3 -0
- vellum/workflows/nodes/experimental/tool_calling_node/node.py +147 -0
- vellum/workflows/nodes/experimental/tool_calling_node/utils.py +132 -0
- vellum/workflows/nodes/utils.py +4 -2
- vellum/workflows/outputs/base.py +3 -2
- vellum/workflows/references/output.py +20 -0
- vellum/workflows/runner/runner.py +37 -17
- vellum/workflows/state/base.py +64 -19
- vellum/workflows/state/tests/test_state.py +31 -22
- vellum/workflows/types/stack.py +11 -0
- vellum/workflows/workflows/base.py +13 -18
- vellum/workflows/workflows/tests/test_base_workflow.py +96 -9
- {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/METADATA +1 -1
- {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/RECORD +82 -75
- vellum_cli/push.py +2 -5
- vellum_cli/tests/test_push.py +52 -0
- vellum_ee/workflows/display/base.py +14 -1
- vellum_ee/workflows/display/nodes/base_node_display.py +56 -14
- vellum_ee/workflows/display/nodes/get_node_display_class.py +9 -15
- vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +36 -0
- vellum_ee/workflows/display/nodes/vellum/base_adornment_node.py +3 -2
- vellum_ee/workflows/display/nodes/vellum/retry_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_code_execution_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_error_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_note_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_node.py +55 -3
- vellum_ee/workflows/display/nodes/vellum/tests/test_retry_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_templating_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_try_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/try_node.py +1 -2
- vellum_ee/workflows/display/nodes/vellum/utils.py +7 -1
- vellum_ee/workflows/display/tests/{test_vellum_workflow_display.py → test_base_workflow_display.py} +10 -22
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +4 -6
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +7 -16
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +2 -6
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +1 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +3 -10
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +4 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +2 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +7 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_prompt_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +1 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +7 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +1 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +2 -5
- vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +2 -7
- vellum_ee/workflows/display/types.py +5 -4
- vellum_ee/workflows/display/utils/exceptions.py +7 -0
- vellum_ee/workflows/display/utils/registry.py +37 -0
- vellum_ee/workflows/display/utils/vellum.py +2 -1
- vellum_ee/workflows/display/workflows/base_workflow_display.py +281 -43
- vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +34 -21
- vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +58 -20
- vellum_ee/workflows/display/workflows/vellum_workflow_display.py +4 -257
- vellum_ee/workflows/tests/local_workflow/display/workflow.py +2 -2
- {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/LICENSE +0 -0
- {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/WHEEL +0 -0
- {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/entry_points.txt +0 -0
vellum/workflows/runner/runner.py CHANGED

```diff
@@ -4,11 +4,11 @@ from dataclasses import dataclass
 import logging
 from queue import Empty, Queue
 from threading import Event as ThreadingEvent, Thread
-from uuid import UUID
+from uuid import UUID, uuid4
 from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Tuple, Type, Union
 
 from vellum.workflows.constants import undefined
 from vellum.workflows.context import ExecutionContext, execution_context, get_execution_context
 from vellum.workflows.descriptors.base import BaseDescriptor
 from vellum.workflows.edges.edge import Edge
 from vellum.workflows.errors import WorkflowError, WorkflowErrorCode
@@ -30,7 +30,7 @@ from vellum.workflows.events.node import (
     NodeExecutionRejectedBody,
     NodeExecutionStreamingBody,
 )
-from vellum.workflows.events.types import BaseEvent, NodeParentContext, WorkflowParentContext
+from vellum.workflows.events.types import BaseEvent, NodeParentContext, ParentContext, WorkflowParentContext
 from vellum.workflows.events.workflow import (
     WorkflowExecutionFulfilledBody,
     WorkflowExecutionInitiatedBody,
@@ -90,6 +90,7 @@ class WorkflowRunner(Generic[StateType]):
 
         self.workflow = workflow
         self._is_resuming = False
+        self._should_emit_initial_state = True
         if entrypoint_nodes:
             if len(list(entrypoint_nodes)) > 1:
                 raise ValueError("Cannot resume from multiple nodes")
@@ -98,7 +99,8 @@ class WorkflowRunner(Generic[StateType]):
             # https://app.shortcut.com/vellum/story/4408
             node = next(iter(entrypoint_nodes))
             if state:
-                self._initial_state = state
+                self._initial_state = deepcopy(state)
+                self._initial_state.meta.span_id = uuid4()
             else:
                 self._initial_state = self.workflow.get_state_at_node(node)
             self._entrypoints = entrypoint_nodes
@@ -123,8 +125,13 @@ class WorkflowRunner(Generic[StateType]):
             if state:
                 self._initial_state = deepcopy(state)
                 self._initial_state.meta.workflow_inputs = normalized_inputs
+                self._initial_state.meta.span_id = uuid4()
             else:
                 self._initial_state = self.workflow.get_default_state(normalized_inputs)
+                # We don't want to emit the initial state on the base case of Workflow Runs, since
+                # all of that data is redundant and is derivable. It also clearly communicates that
+                # there was no initial state provided by the user to invoke the workflow.
+                self._should_emit_initial_state = False
             self._entrypoints = self.workflow.get_entrypoints()
 
         # This queue is responsible for sending events from WorkflowRunner to the outside world
@@ -239,7 +246,8 @@ class WorkflowRunner(Generic[StateType]):
                 instance=None,
                 outputs_class=node.Outputs,
             )
-            node.state.meta.node_outputs[output_descriptor] = streaming_output_queues[output.name]
+            with node.state.__quiet__():
+                node.state.meta.node_outputs[output_descriptor] = streaming_output_queues[output.name]
             initiated_output: BaseOutput = BaseOutput(name=output.name)
             initiated_ports = initiated_output > ports
             self._workflow_event_inner_queue.put(
@@ -297,13 +305,14 @@ class WorkflowRunner(Generic[StateType]):
 
         node.state.meta.node_execution_cache.fulfill_node_execution(node.__class__, span_id)
 
-        for descriptor, output_value in outputs:
-            if output_value is undefined:
-                if descriptor in node.state.meta.node_outputs:
-                    del node.state.meta.node_outputs[descriptor]
-                continue
+        with node.state.__atomic__():
+            for descriptor, output_value in outputs:
+                if output_value is undefined:
+                    if descriptor in node.state.meta.node_outputs:
+                        del node.state.meta.node_outputs[descriptor]
+                    continue
 
-            node.state.meta.node_outputs[descriptor] = output_value
+                node.state.meta.node_outputs[descriptor] = output_value
 
         invoked_ports = ports(outputs, node.state)
         self._workflow_event_inner_queue.put(
@@ -365,11 +374,16 @@ class WorkflowRunner(Generic[StateType]):
 
         logger.debug(f"Finished running node: {node.__class__.__name__}")
 
-    def _context_run_work_item(self, node: BaseNode[StateType], span_id: UUID, parent_context: ParentContext) -> None:
-        execution = get_execution_context()
+    def _context_run_work_item(
+        self,
+        node: BaseNode[StateType],
+        span_id: UUID,
+        parent_context: ParentContext,
+        trace_id: UUID,
+    ) -> None:
         with execution_context(
-            parent_context=parent_context or execution.parent_context,
-            trace_id=execution.trace_id,
+            parent_context=parent_context,
+            trace_id=trace_id,
         ):
             self._run_work_item(node, span_id)
 
@@ -419,14 +433,19 @@ class WorkflowRunner(Generic[StateType]):
         if not node_class.Trigger.should_initiate(state, all_deps, node_span_id):
             return
 
-        parent_context = get_execution_context().parent_context
+        execution = get_execution_context()
         node = node_class(state=state, context=self.workflow.context)
         state.meta.node_execution_cache.initiate_node_execution(node_class, node_span_id)
         self._active_nodes_by_execution_id[node_span_id] = ActiveNode(node=node)
 
         worker_thread = Thread(
             target=self._context_run_work_item,
-            kwargs={"node": node, "span_id": node_span_id, "parent_context": parent_context},
+            kwargs={
+                "node": node,
+                "span_id": node_span_id,
+                "parent_context": execution.parent_context,
+                "trace_id": execution.trace_id,
+            },
         )
         worker_thread.start()
 
@@ -500,6 +519,7 @@ class WorkflowRunner(Generic[StateType]):
             body=WorkflowExecutionInitiatedBody(
                 workflow_definition=self.workflow.__class__,
                 inputs=self._initial_state.meta.workflow_inputs,
+                initial_state=deepcopy(self._initial_state) if self._should_emit_initial_state else None,
             ),
             parent=self._execution_context.parent_context,
         )
```
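The worker-thread change above is the crux of this file: the execution context is now captured in the scheduling thread and its `parent_context`/`trace_id` are passed into `_context_run_work_item` explicitly. A minimal standalone sketch of why that is necessary, assuming nothing about vellum internals beyond their use of `contextvars` (all names here are illustrative):

```python
# A new thread starts with a fresh contextvars context, so values set in the
# parent thread are invisible unless they are captured and handed over.
import threading
from contextvars import ContextVar

trace_id_var: ContextVar[str] = ContextVar("trace_id", default="<unset>")

def worker_naive() -> None:
    # Reads the var inside the thread: only the default is visible.
    print("naive:", trace_id_var.get())  # -> <unset>

def worker_explicit(trace_id: str) -> None:
    # Re-establishes the captured value inside the thread's own context.
    trace_id_var.set(trace_id)
    print("explicit:", trace_id_var.get())  # -> abc-123

trace_id_var.set("abc-123")
captured = trace_id_var.get()  # capture in the parent thread, like get_execution_context()

threading.Thread(target=worker_naive).start()
threading.Thread(target=worker_explicit, kwargs={"trace_id": captured}).start()
```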
vellum/workflows/state/base.py CHANGED

```diff
@@ -1,4 +1,5 @@
 from collections import defaultdict
+from contextlib import contextmanager
 from copy import deepcopy
 from dataclasses import field
 from datetime import datetime
@@ -6,13 +7,14 @@ import logging
 from queue import Queue
 from threading import Lock
 from uuid import UUID, uuid4
-from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, cast
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union, cast
 from typing_extensions import dataclass_transform
 
 from pydantic import GetCoreSchemaHandler, ValidationInfo, field_serializer, field_validator
 from pydantic_core import core_schema
 
 from vellum.core.pydantic_utilities import UniversalBaseModel
+from vellum.utils.uuid import is_valid_uuid
 from vellum.workflows.constants import undefined
 from vellum.workflows.edges.edge import Edge
 from vellum.workflows.inputs.base import BaseInputs
@@ -108,18 +110,30 @@ class NodeExecutionCache:
         self._node_executions_queued = defaultdict(list)
 
     @classmethod
-    def deserialize(cls, raw_data: dict, nodes: Dict[str, Type["BaseNode"]]):
+    def deserialize(cls, raw_data: dict, nodes: Dict[Union[str, UUID], Type["BaseNode"]]):
         cache = cls()
 
+        def get_node_class(node_id: Any) -> Optional[Type["BaseNode"]]:
+            if not isinstance(node_id, str):
+                return None
+
+            if is_valid_uuid(node_id):
+                return nodes.get(UUID(node_id))
+
+            return nodes.get(node_id)
+
         dependencies_invoked = raw_data.get("dependencies_invoked")
         if isinstance(dependencies_invoked, dict):
             for execution_id, dependencies in dependencies_invoked.items():
-                cache._dependencies_invoked[UUID(execution_id)] = {nodes[dep] for dep in dependencies if dep in nodes}
+                dependency_classes = {get_node_class(dep) for dep in dependencies}
+                cache._dependencies_invoked[UUID(execution_id)] = {
+                    dep_class for dep_class in dependency_classes if dep_class is not None
+                }
 
         node_executions_fulfilled = raw_data.get("node_executions_fulfilled")
         if isinstance(node_executions_fulfilled, dict):
             for node, execution_ids in node_executions_fulfilled.items():
-                node_class = nodes.get(node)
+                node_class = get_node_class(node)
                 if not node_class:
                     continue
 
@@ -130,7 +144,7 @@ class NodeExecutionCache:
         node_executions_initiated = raw_data.get("node_executions_initiated")
         if isinstance(node_executions_initiated, dict):
             for node, execution_ids in node_executions_initiated.items():
-                node_class = nodes.get(node)
+                node_class = get_node_class(node)
                 if not node_class:
                     continue
 
@@ -141,7 +155,7 @@ class NodeExecutionCache:
         node_executions_queued = raw_data.get("node_executions_queued")
         if isinstance(node_executions_queued, dict):
             for node, execution_ids in node_executions_queued.items():
-                node_class = nodes.get(node)
+                node_class = get_node_class(node)
                 if not node_class:
                     continue
 
@@ -192,17 +206,18 @@ class NodeExecutionCache:
     def dump(self) -> Dict[str, Any]:
         return {
             "dependencies_invoked": {
-                str(execution_id): [str(dep) for dep in dependencies]
+                str(execution_id): [str(dep.__id__) for dep in dependencies]
                 for execution_id, dependencies in self._dependencies_invoked.items()
             },
             "node_executions_initiated": {
-                str(node): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
+                str(node.__id__): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
             },
             "node_executions_fulfilled": {
-                str(node): execution_ids.dump() for node, execution_ids in self._node_executions_fulfilled.items()
+                str(node.__id__): execution_ids.dump()
+                for node, execution_ids in self._node_executions_fulfilled.items()
             },
             "node_executions_queued": {
-                str(node): execution_ids for node, execution_ids in self._node_executions_queued.items()
+                str(node.__id__): execution_ids for node, execution_ids in self._node_executions_queued.items()
             },
         }
 
```
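Keying the serialized cache by `node.__id__` (a stable UUID) instead of `str(node)` means persisted state survives a class rename, while `deserialize` still accepts legacy name keys. A hypothetical reader in the spirit of `get_node_class` above, using a try/except where the SDK uses its `is_valid_uuid` helper:

```python
# Resolve a persisted cache key that may be either a legacy display name
# ("MyNode") or a stable UUID string. Illustrative only, not the vellum API.
from typing import Dict, Optional, Type, Union
from uuid import UUID

def resolve(key: str, nodes: Dict[Union[str, UUID], type]) -> Optional[type]:
    try:
        return nodes.get(UUID(key))  # new format: the node's stable __id__
    except ValueError:
        return nodes.get(key)        # old format: str(node) display name
```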
```diff
@@ -278,7 +293,7 @@ class StateMeta(UniversalBaseModel):
 
     @field_serializer("node_outputs")
     def serialize_node_outputs(self, node_outputs: Dict[OutputReference, Any], _info: Any) -> Dict[str, Any]:
-        return {str(descriptor): value for descriptor, value in node_outputs.items()}
+        return {str(descriptor.id): value for descriptor, value in node_outputs.items()}
 
     @field_validator("node_outputs", mode="before")
     @classmethod
@@ -289,15 +304,22 @@ class StateMeta(UniversalBaseModel):
             return node_outputs
 
         raw_workflow_nodes = workflow_definition.get_nodes()
-        workflow_node_outputs = {}
+        workflow_node_outputs: Dict[Union[str, UUID], OutputReference] = {}
         for node in raw_workflow_nodes:
             for output in node.Outputs:
                 workflow_node_outputs[str(output)] = output
+                output_id = node.__output_ids__.get(output.name)
+                if output_id:
+                    workflow_node_outputs[output_id] = output
 
         node_output_keys = list(node_outputs.keys())
         deserialized_node_outputs = {}
         for node_output_key in node_output_keys:
-            output_reference = workflow_node_outputs.get(node_output_key)
+            if is_valid_uuid(node_output_key):
+                output_reference = workflow_node_outputs.get(UUID(node_output_key))
+            else:
+                output_reference = workflow_node_outputs.get(node_output_key)
+
             if not output_reference:
                 continue
 
@@ -315,10 +337,11 @@ class StateMeta(UniversalBaseModel):
         if not workflow_definition:
             return node_execution_cache
 
-        nodes_cache: Dict[str, Type["BaseNode"]] = {}
+        nodes_cache: Dict[Union[str, UUID], Type["BaseNode"]] = {}
         raw_workflow_nodes = workflow_definition.get_nodes()
         for node in raw_workflow_nodes:
             nodes_cache[str(node)] = node
+            nodes_cache[node.__id__] = node
 
         return NodeExecutionCache.deserialize(node_execution_cache, nodes_cache)
 
@@ -404,11 +427,11 @@ class BaseState(metaclass=_BaseStateMeta):
     meta: StateMeta = field(init=False)
 
     __lock__: Lock = field(init=False)
-    __is_initializing__: bool = field(init=False)
+    __is_quiet__: bool = field(init=False)
     __snapshot_callback__: Callable[["BaseState"], None] = field(init=False)
 
     def __init__(self, meta: Optional[StateMeta] = None, **kwargs: Any) -> None:
-        self.__is_initializing__ = True
+        self.__is_quiet__ = True
         self.__snapshot_callback__ = lambda state: None
         self.__lock__ = Lock()
 
@@ -418,14 +441,14 @@ class BaseState(metaclass=_BaseStateMeta):
         # Make all class attribute values snapshottable
         for name, value in self.__class__.__dict__.items():
             if not name.startswith("_") and name != "meta":
-                # Bypass __is_initializing__ instead of `setattr`
+                # Bypass __is_quiet__ instead of `setattr`
                 snapshottable_value = _make_snapshottable(value, self.__snapshot__)
                 super().__setattr__(name, snapshottable_value)
 
         for name, value in kwargs.items():
             setattr(self, name, value)
 
-        self.__is_initializing__ = False
+        self.__is_quiet__ = False
 
     def __deepcopy__(self, memo: Any) -> "BaseState":
         new_state = deepcopy_with_exclusions(
@@ -472,7 +495,7 @@ class BaseState(metaclass=_BaseStateMeta):
             return self.__dict__[key]
 
     def __setattr__(self, name: str, value: Any) -> None:
-        if name.startswith("_") or self.__is_initializing__:
+        if name.startswith("_"):
             super().__setattr__(name, value)
             return
 
@@ -513,11 +536,33 @@ class BaseState(metaclass=_BaseStateMeta):
         Snapshots the current state to the workflow emitter. The invoked callback is overridden by the
         workflow runner.
         """
+        if self.__is_quiet__:
+            return
+
         try:
             self.__snapshot_callback__(deepcopy(self))
         except Exception:
             logger.exception("Failed to snapshot Workflow state.")
 
+    @contextmanager
+    def __quiet__(self):
+        prev = self.__is_quiet__
+        self.__is_quiet__ = True
+        try:
+            yield
+        finally:
+            self.__is_quiet__ = prev
+
+    @contextmanager
+    def __atomic__(self):
+        prev = self.__is_quiet__
+        self.__is_quiet__ = True
+        try:
+            yield
+        finally:
+            self.__is_quiet__ = prev
+            self.__snapshot__()
+
     @classmethod
     def __get_pydantic_core_schema__(
         cls, source_type: Type[Any], handler: GetCoreSchemaHandler
```
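Both new context managers suppress per-write snapshots; `__atomic__` additionally emits a single snapshot once the batch completes, which is what lets the runner above write several node outputs as one state transition. A minimal standalone sketch of the pattern (names invented, not the vellum API):

```python
# Batch writes under a "quiet" flag, then emit exactly one snapshot.
from contextlib import contextmanager

class Observable:
    def __init__(self) -> None:
        self._quiet = False
        self.snapshots = 0

    def _snapshot(self) -> None:
        if not self._quiet:
            self.snapshots += 1

    @contextmanager
    def atomic(self):
        prev, self._quiet = self._quiet, True
        try:
            yield
        finally:
            self._quiet = prev
            self._snapshot()  # one snapshot for the whole batch

obs = Observable()
with obs.atomic():
    obs._snapshot()  # suppressed
    obs._snapshot()  # suppressed
assert obs.snapshots == 1
```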
vellum/workflows/state/tests/test_state.py CHANGED

```diff
@@ -1,17 +1,14 @@
 import pytest
-from collections import defaultdict
 from copy import deepcopy
 import json
 from queue import Queue
-from typing import Dict
+from typing import Dict, cast
 
 from vellum.workflows.nodes.bases import BaseNode
 from vellum.workflows.outputs.base import BaseOutputs
 from vellum.workflows.state.base import BaseState
 from vellum.workflows.state.encoder import DefaultStateEncoder
 
-snapshot_count: Dict[int, int] = defaultdict(int)
-
 
 @pytest.fixture()
 def mock_deepcopy(mocker):
@@ -27,9 +24,19 @@ class MockState(BaseState):
     foo: str
     nested_dict: Dict[str, int] = {}
 
-    def __snapshot__(self) -> None:
-        super().__snapshot__()
-        snapshot_count[id(self)] += 1
+    __snapshot_count__: int = 0
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.__snapshot_callback__ = lambda _: self.__mock_snapshot__()
+
+    def __mock_snapshot__(self) -> None:
+        self.__snapshot_count__ += 1
+
+    def __deepcopy__(self, memo: dict) -> "MockState":
+        new_state = cast(MockState, super().__deepcopy__(memo))
+        new_state.__snapshot_count__ = 0
+        return new_state
 
 
 class MockNode(BaseNode):
@@ -40,53 +47,56 @@ class MockNode(BaseNode):
         baz: str
 
 
+MOCK_NODE_OUTPUT_ID = "e4dc3136-0c27-4bda-b3ab-ea355d5219d6"
+
+
 def test_state_snapshot__node_attribute_edit():
     # GIVEN an initial state instance
     state = MockState(foo="bar")
-    assert snapshot_count[id(state)] == 0
+    assert state.__snapshot_count__ == 0
 
     # WHEN we edit an attribute
     state.foo = "baz"
 
     # THEN the snapshot is emitted
-    assert snapshot_count[id(state)] == 1
+    assert state.__snapshot_count__ == 1
 
 
 def test_state_snapshot__node_output_edit():
     # GIVEN an initial state instance
     state = MockState(foo="bar")
-    assert snapshot_count[id(state)] == 0
+    assert state.__snapshot_count__ == 0
 
     # WHEN we add a Node Output to state
     for output in MockNode.Outputs:
         state.meta.node_outputs[output] = "hello"
 
     # THEN the snapshot is emitted
-    assert snapshot_count[id(state)] == 1
+    assert state.__snapshot_count__ == 1
 
 
 def test_state_snapshot__nested_dictionary_edit():
     # GIVEN an initial state instance
     state = MockState(foo="bar")
-    assert snapshot_count[id(state)] == 0
+    assert state.__snapshot_count__ == 0
 
     # WHEN we edit a nested dictionary
     state.nested_dict["hello"] = 1
 
     # THEN the snapshot is emitted
-    assert snapshot_count[id(state)] == 1
+    assert state.__snapshot_count__ == 1
 
 
 def test_state_snapshot__external_input_edit():
     # GIVEN an initial state instance
     state = MockState(foo="bar")
-    assert snapshot_count[id(state)] == 0
+    assert state.__snapshot_count__ == 0
 
     # WHEN we add an external input to state
     state.meta.external_inputs[MockNode.ExternalInputs.message] = "hello"
 
     # THEN the snapshot is emitted
-    assert snapshot_count[id(state)] == 1
+    assert state.__snapshot_count__ == 1
 
 
 def test_state_deepcopy():
@@ -103,7 +113,6 @@ def test_state_deepcopy():
     assert deepcopied_state.meta.node_outputs == state.meta.node_outputs
 
 
-@pytest.mark.skip(reason="https://app.shortcut.com/vellum/story/5654")
 def test_state_deepcopy__with_node_output_updates():
     # GIVEN an initial state instance
     state = MockState(foo="bar")
@@ -121,10 +130,10 @@ def test_state_deepcopy__with_node_output_updates():
     assert deepcopied_state.meta.node_outputs[MockNode.Outputs.baz] == "hello"
 
     # AND the original state has had the correct number of snapshots
-    assert snapshot_count[id(state)] == 2
+    assert state.__snapshot_count__ == 2
 
     # AND the copied state has had the correct number of snapshots
-    assert snapshot_count[id(deepcopied_state)] == 0
+    assert deepcopied_state.__snapshot_count__ == 0
 
 
 def test_state_json_serialization__with_node_output_updates():
@@ -138,7 +147,7 @@ def test_state_json_serialization__with_node_output_updates():
     json_state = json.loads(json.dumps(state, cls=DefaultStateEncoder))
 
     # THEN the state is serialized correctly
-    assert json_state["meta"]["node_outputs"] == {"MockNode.Outputs.baz": "hello"}
+    assert json_state["meta"]["node_outputs"] == {MOCK_NODE_OUTPUT_ID: "hello"}
 
 
 def test_state_deepcopy__with_external_input_updates():
@@ -158,10 +167,10 @@ def test_state_deepcopy__with_external_input_updates():
     assert deepcopied_state.meta.external_inputs[MockNode.ExternalInputs.message] == "hello"
 
     # AND the original state has had the correct number of snapshots
-    assert snapshot_count[id(state)] == 2
+    assert state.__snapshot_count__ == 2
 
     # AND the copied state has had the correct number of snapshots
-    assert snapshot_count[id(deepcopied_state)] == 0
+    assert deepcopied_state.__snapshot_count__ == 0
 
 
 def test_state_json_serialization__with_queue():
@@ -179,7 +188,7 @@ def test_state_json_serialization__with_queue():
     json_state = json.loads(json.dumps(state, cls=DefaultStateEncoder))
 
     # THEN the state is serialized correctly with the queue turned into a list
-    assert json_state["meta"]["node_outputs"] == {"MockNode.Outputs.baz": ["test1", "test2"]}
+    assert json_state["meta"]["node_outputs"] == {MOCK_NODE_OUTPUT_ID: ["test1", "test2"]}
 
 
 def test_state_snapshot__deepcopy_fails__logs_error(mock_deepcopy, mock_logger):
```
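The test refactor replaces a module-level counter keyed by `id(state)` with a per-instance counter wired through the snapshot callback; `id()` values can be reused once an object is garbage-collected, so the old global could miscount across tests. A minimal sketch of the injection pattern (stand-in class, not the real API):

```python
# Count snapshots per instance via an injected callback, like
# MockState does with __snapshot_callback__ above.
class FakeState:
    def __init__(self) -> None:
        self.snapshot_count = 0
        self.on_snapshot = lambda state: None  # default: no-op

    def snapshot(self) -> None:
        self.on_snapshot(self)

state = FakeState()
state.on_snapshot = lambda s: setattr(s, "snapshot_count", s.snapshot_count + 1)
state.snapshot()
state.snapshot()
assert state.snapshot_count == 2
```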
vellum/workflows/types/stack.py CHANGED

```diff
@@ -37,3 +37,14 @@ class Stack(Generic[_T]):
 
     def dump(self) -> List[_T]:
         return [item for item in self._items][::-1]
+
+    @classmethod
+    def from_list(cls, items: List[_T]) -> "Stack[_T]":
+        stack = cls()
+        stack.extend(items)
+        return stack
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, Stack):
+            return False
+        return self._items == other._items
```
vellum/workflows/workflows/base.py CHANGED

```diff
@@ -80,6 +80,11 @@ class _BaseWorkflowMeta(type):
     def __new__(mcs, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any]) -> Any:
         if "graph" not in dct:
             dct["graph"] = set()
+            for base in bases:
+                base_graph = getattr(base, "graph", None)
+                if base_graph:
+                    dct["graph"] = base_graph
+                    break
 
         if "Outputs" in dct:
             outputs_class = dct["Outputs"]
@@ -146,7 +151,7 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
 
     WorkflowEvent = Union[  # type: ignore
         GenericWorkflowEvent,
-        WorkflowExecutionInitiatedEvent[InputsType],  # type: ignore[valid-type]
+        WorkflowExecutionInitiatedEvent[InputsType, StateType],  # type: ignore[valid-type]
         WorkflowExecutionFulfilledEvent[Outputs],
         WorkflowExecutionSnapshottedEvent[StateType],  # type: ignore[valid-type]
     ]
@@ -335,7 +340,7 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
 
         if not last_event:
             return WorkflowExecutionRejectedEvent(
-                trace_id=uuid4(),
+                trace_id=self._execution_context.trace_id,
                 span_id=uuid4(),
                 body=WorkflowExecutionRejectedBody(
                     error=WorkflowError(
@@ -348,7 +353,7 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
 
         if not first_event:
             return WorkflowExecutionRejectedEvent(
-                trace_id=uuid4(),
+                trace_id=self._execution_context.trace_id,
                 span_id=uuid4(),
                 body=WorkflowExecutionRejectedBody(
                     error=WorkflowError(
@@ -367,7 +372,7 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
             return last_event
 
         return WorkflowExecutionRejectedEvent(
-            trace_id=first_event.trace_id,
+            trace_id=self._execution_context.trace_id,
            span_id=first_event.span_id,
            body=WorkflowExecutionRejectedBody(
                workflow_definition=self.__class__,
@@ -482,21 +487,11 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
             return self.get_inputs_class()()
 
     def get_default_state(self, workflow_inputs: Optional[InputsType] = None) -> StateType:
-        execution_context = self._execution_context
         return self.get_state_class()(
-            meta=(
-                StateMeta(
-                    parent=self._parent_state,
-                    workflow_inputs=workflow_inputs or self.get_default_inputs(),
-                    trace_id=execution_context.trace_id,
-                    workflow_definition=self.__class__,
-                )
-                if execution_context and int(execution_context.trace_id)
-                else StateMeta(
-                    parent=self._parent_state,
-                    workflow_inputs=workflow_inputs or self.get_default_inputs(),
-                    workflow_definition=self.__class__,
-                )
+            meta=StateMeta(
+                parent=self._parent_state,
+                workflow_inputs=workflow_inputs or self.get_default_inputs(),
+                workflow_definition=self.__class__,
             )
         )
 
```
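The `_BaseWorkflowMeta` change makes `graph` inheritable: a subclass that omits `graph` now picks up the first base's graph instead of being reset to an empty set. A hypothetical sketch under that assumption (node and workflow names are invented for illustration):

```python
from vellum.workflows import BaseWorkflow
from vellum.workflows.inputs.base import BaseInputs
from vellum.workflows.nodes.bases import BaseNode
from vellum.workflows.state.base import BaseState

class StartNode(BaseNode):
    pass

class ParentWorkflow(BaseWorkflow[BaseInputs, BaseState]):
    graph = StartNode

class ChildWorkflow(ParentWorkflow):
    # No `graph` declared here: previously the metaclass reset it to set(),
    # so the subclass silently lost its parent's graph. Now the metaclass
    # copies ParentWorkflow.graph into this class instead.
    pass
```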