vellum-ai 0.14.38__py3-none-any.whl → 0.14.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82):
  1. vellum/__init__.py +2 -0
  2. vellum/client/core/client_wrapper.py +1 -1
  3. vellum/client/types/__init__.py +2 -0
  4. vellum/client/types/test_suite_run_progress.py +20 -0
  5. vellum/client/types/test_suite_run_read.py +3 -0
  6. vellum/client/types/vellum_sdk_error_code_enum.py +1 -0
  7. vellum/client/types/workflow_execution_event_error_code.py +1 -0
  8. vellum/types/test_suite_run_progress.py +3 -0
  9. vellum/workflows/errors/types.py +1 -0
  10. vellum/workflows/events/tests/test_event.py +1 -0
  11. vellum/workflows/events/workflow.py +13 -3
  12. vellum/workflows/nodes/bases/tests/test_base_node.py +24 -0
  13. vellum/workflows/nodes/core/try_node/node.py +1 -2
  14. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +7 -1
  15. vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +6 -1
  16. vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +26 -0
  17. vellum/workflows/nodes/experimental/tool_calling_node/__init__.py +3 -0
  18. vellum/workflows/nodes/experimental/tool_calling_node/node.py +147 -0
  19. vellum/workflows/nodes/experimental/tool_calling_node/utils.py +132 -0
  20. vellum/workflows/nodes/utils.py +4 -2
  21. vellum/workflows/outputs/base.py +3 -2
  22. vellum/workflows/references/output.py +20 -0
  23. vellum/workflows/runner/runner.py +37 -17
  24. vellum/workflows/state/base.py +64 -19
  25. vellum/workflows/state/tests/test_state.py +31 -22
  26. vellum/workflows/types/stack.py +11 -0
  27. vellum/workflows/workflows/base.py +13 -18
  28. vellum/workflows/workflows/tests/test_base_workflow.py +96 -9
  29. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/METADATA +1 -1
  30. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/RECORD +82 -75
  31. vellum_cli/push.py +2 -5
  32. vellum_cli/tests/test_push.py +52 -0
  33. vellum_ee/workflows/display/base.py +14 -1
  34. vellum_ee/workflows/display/nodes/base_node_display.py +56 -14
  35. vellum_ee/workflows/display/nodes/get_node_display_class.py +9 -15
  36. vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +36 -0
  37. vellum_ee/workflows/display/nodes/vellum/base_adornment_node.py +3 -2
  38. vellum_ee/workflows/display/nodes/vellum/retry_node.py +1 -2
  39. vellum_ee/workflows/display/nodes/vellum/tests/test_code_execution_node.py +1 -2
  40. vellum_ee/workflows/display/nodes/vellum/tests/test_error_node.py +1 -2
  41. vellum_ee/workflows/display/nodes/vellum/tests/test_note_node.py +1 -2
  42. vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_node.py +55 -3
  43. vellum_ee/workflows/display/nodes/vellum/tests/test_retry_node.py +1 -2
  44. vellum_ee/workflows/display/nodes/vellum/tests/test_templating_node.py +1 -2
  45. vellum_ee/workflows/display/nodes/vellum/tests/test_try_node.py +1 -2
  46. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +2 -2
  47. vellum_ee/workflows/display/nodes/vellum/try_node.py +1 -2
  48. vellum_ee/workflows/display/nodes/vellum/utils.py +7 -1
  49. vellum_ee/workflows/display/tests/{test_vellum_workflow_display.py → test_base_workflow_display.py} +10 -22
  50. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +4 -6
  51. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +7 -16
  52. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +2 -6
  53. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +1 -2
  54. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +3 -10
  55. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +4 -5
  56. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +1 -4
  57. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +1 -4
  58. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +2 -5
  59. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +7 -5
  60. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_prompt_node_serialization.py +1 -4
  61. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -4
  62. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +1 -2
  63. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +1 -4
  64. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -4
  65. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +7 -5
  66. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +1 -4
  67. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -4
  68. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +1 -4
  69. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +2 -5
  70. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +2 -7
  71. vellum_ee/workflows/display/types.py +5 -4
  72. vellum_ee/workflows/display/utils/exceptions.py +7 -0
  73. vellum_ee/workflows/display/utils/registry.py +37 -0
  74. vellum_ee/workflows/display/utils/vellum.py +2 -1
  75. vellum_ee/workflows/display/workflows/base_workflow_display.py +281 -43
  76. vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +34 -21
  77. vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +58 -20
  78. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +4 -257
  79. vellum_ee/workflows/tests/local_workflow/display/workflow.py +2 -2
  80. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/LICENSE +0 -0
  81. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/WHEEL +0 -0
  82. {vellum_ai-0.14.38.dist-info → vellum_ai-0.14.40.dist-info}/entry_points.txt +0 -0
@@ -4,11 +4,11 @@ from dataclasses import dataclass
4
4
  import logging
5
5
  from queue import Empty, Queue
6
6
  from threading import Event as ThreadingEvent, Thread
7
- from uuid import UUID
7
+ from uuid import UUID, uuid4
8
8
  from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Tuple, Type, Union
9
9
 
10
10
  from vellum.workflows.constants import undefined
11
- from vellum.workflows.context import ExecutionContext, execution_context, get_execution_context, get_parent_context
11
+ from vellum.workflows.context import ExecutionContext, execution_context, get_execution_context
12
12
  from vellum.workflows.descriptors.base import BaseDescriptor
13
13
  from vellum.workflows.edges.edge import Edge
14
14
  from vellum.workflows.errors import WorkflowError, WorkflowErrorCode
@@ -30,7 +30,7 @@ from vellum.workflows.events.node import (
30
30
  NodeExecutionRejectedBody,
31
31
  NodeExecutionStreamingBody,
32
32
  )
33
- from vellum.workflows.events.types import BaseEvent, NodeParentContext, WorkflowParentContext
33
+ from vellum.workflows.events.types import BaseEvent, NodeParentContext, ParentContext, WorkflowParentContext
34
34
  from vellum.workflows.events.workflow import (
35
35
  WorkflowExecutionFulfilledBody,
36
36
  WorkflowExecutionInitiatedBody,
@@ -90,6 +90,7 @@ class WorkflowRunner(Generic[StateType]):
90
90
 
91
91
  self.workflow = workflow
92
92
  self._is_resuming = False
93
+ self._should_emit_initial_state = True
93
94
  if entrypoint_nodes:
94
95
  if len(list(entrypoint_nodes)) > 1:
95
96
  raise ValueError("Cannot resume from multiple nodes")
@@ -98,7 +99,8 @@ class WorkflowRunner(Generic[StateType]):
98
99
  # https://app.shortcut.com/vellum/story/4408
99
100
  node = next(iter(entrypoint_nodes))
100
101
  if state:
101
- self._initial_state = state
102
+ self._initial_state = deepcopy(state)
103
+ self._initial_state.meta.span_id = uuid4()
102
104
  else:
103
105
  self._initial_state = self.workflow.get_state_at_node(node)
104
106
  self._entrypoints = entrypoint_nodes
@@ -123,8 +125,13 @@ class WorkflowRunner(Generic[StateType]):
123
125
  if state:
124
126
  self._initial_state = deepcopy(state)
125
127
  self._initial_state.meta.workflow_inputs = normalized_inputs
128
+ self._initial_state.meta.span_id = uuid4()
126
129
  else:
127
130
  self._initial_state = self.workflow.get_default_state(normalized_inputs)
131
+ # We don't want to emit the initial state on the base case of Workflow Runs, since
132
+ # all of that data is redundant and is derivable. It also clearly communicates that
133
+ # there was no initial state provided by the user to invoke the workflow.
134
+ self._should_emit_initial_state = False
128
135
  self._entrypoints = self.workflow.get_entrypoints()
129
136
 
130
137
  # This queue is responsible for sending events from WorkflowRunner to the outside world
@@ -239,7 +246,8 @@ class WorkflowRunner(Generic[StateType]):
239
246
  instance=None,
240
247
  outputs_class=node.Outputs,
241
248
  )
242
- node.state.meta.node_outputs[output_descriptor] = streaming_output_queues[output.name]
249
+ with node.state.__quiet__():
250
+ node.state.meta.node_outputs[output_descriptor] = streaming_output_queues[output.name]
243
251
  initiated_output: BaseOutput = BaseOutput(name=output.name)
244
252
  initiated_ports = initiated_output > ports
245
253
  self._workflow_event_inner_queue.put(
@@ -297,13 +305,14 @@ class WorkflowRunner(Generic[StateType]):
297
305
 
298
306
  node.state.meta.node_execution_cache.fulfill_node_execution(node.__class__, span_id)
299
307
 
300
- for descriptor, output_value in outputs:
301
- if output_value is undefined:
302
- if descriptor in node.state.meta.node_outputs:
303
- del node.state.meta.node_outputs[descriptor]
304
- continue
308
+ with node.state.__atomic__():
309
+ for descriptor, output_value in outputs:
310
+ if output_value is undefined:
311
+ if descriptor in node.state.meta.node_outputs:
312
+ del node.state.meta.node_outputs[descriptor]
313
+ continue
305
314
 
306
- node.state.meta.node_outputs[descriptor] = output_value
315
+ node.state.meta.node_outputs[descriptor] = output_value
307
316
 
308
317
  invoked_ports = ports(outputs, node.state)
309
318
  self._workflow_event_inner_queue.put(
@@ -365,11 +374,16 @@ class WorkflowRunner(Generic[StateType]):
365
374
 
366
375
  logger.debug(f"Finished running node: {node.__class__.__name__}")
367
376
 
368
- def _context_run_work_item(self, node: BaseNode[StateType], span_id: UUID, parent_context=None) -> None:
369
- execution = get_execution_context()
377
+ def _context_run_work_item(
378
+ self,
379
+ node: BaseNode[StateType],
380
+ span_id: UUID,
381
+ parent_context: ParentContext,
382
+ trace_id: UUID,
383
+ ) -> None:
370
384
  with execution_context(
371
- parent_context=parent_context or execution.parent_context,
372
- trace_id=execution.trace_id,
385
+ parent_context=parent_context,
386
+ trace_id=trace_id,
373
387
  ):
374
388
  self._run_work_item(node, span_id)
375
389
 
@@ -419,14 +433,19 @@ class WorkflowRunner(Generic[StateType]):
419
433
  if not node_class.Trigger.should_initiate(state, all_deps, node_span_id):
420
434
  return
421
435
 
422
- current_parent = get_parent_context()
436
+ execution = get_execution_context()
423
437
  node = node_class(state=state, context=self.workflow.context)
424
438
  state.meta.node_execution_cache.initiate_node_execution(node_class, node_span_id)
425
439
  self._active_nodes_by_execution_id[node_span_id] = ActiveNode(node=node)
426
440
 
427
441
  worker_thread = Thread(
428
442
  target=self._context_run_work_item,
429
- kwargs={"node": node, "span_id": node_span_id, "parent_context": current_parent},
443
+ kwargs={
444
+ "node": node,
445
+ "span_id": node_span_id,
446
+ "parent_context": execution.parent_context,
447
+ "trace_id": execution.trace_id,
448
+ },
430
449
  )
431
450
  worker_thread.start()
432
451
 
@@ -500,6 +519,7 @@ class WorkflowRunner(Generic[StateType]):
500
519
  body=WorkflowExecutionInitiatedBody(
501
520
  workflow_definition=self.workflow.__class__,
502
521
  inputs=self._initial_state.meta.workflow_inputs,
522
+ initial_state=deepcopy(self._initial_state) if self._should_emit_initial_state else None,
503
523
  ),
504
524
  parent=self._execution_context.parent_context,
505
525
  )
@@ -1,4 +1,5 @@
1
1
  from collections import defaultdict
2
+ from contextlib import contextmanager
2
3
  from copy import deepcopy
3
4
  from dataclasses import field
4
5
  from datetime import datetime
@@ -6,13 +7,14 @@ import logging
6
7
  from queue import Queue
7
8
  from threading import Lock
8
9
  from uuid import UUID, uuid4
9
- from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, cast
10
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union, cast
10
11
  from typing_extensions import dataclass_transform
11
12
 
12
13
  from pydantic import GetCoreSchemaHandler, ValidationInfo, field_serializer, field_validator
13
14
  from pydantic_core import core_schema
14
15
 
15
16
  from vellum.core.pydantic_utilities import UniversalBaseModel
17
+ from vellum.utils.uuid import is_valid_uuid
16
18
  from vellum.workflows.constants import undefined
17
19
  from vellum.workflows.edges.edge import Edge
18
20
  from vellum.workflows.inputs.base import BaseInputs
@@ -108,18 +110,30 @@ class NodeExecutionCache:
108
110
  self._node_executions_queued = defaultdict(list)
109
111
 
110
112
  @classmethod
111
- def deserialize(cls, raw_data: dict, nodes: Dict[str, Type["BaseNode"]]):
113
+ def deserialize(cls, raw_data: dict, nodes: Dict[Union[str, UUID], Type["BaseNode"]]):
112
114
  cache = cls()
113
115
 
116
+ def get_node_class(node_id: Any) -> Optional[Type["BaseNode"]]:
117
+ if not isinstance(node_id, str):
118
+ return None
119
+
120
+ if is_valid_uuid(node_id):
121
+ return nodes.get(UUID(node_id))
122
+
123
+ return nodes.get(node_id)
124
+
114
125
  dependencies_invoked = raw_data.get("dependencies_invoked")
115
126
  if isinstance(dependencies_invoked, dict):
116
127
  for execution_id, dependencies in dependencies_invoked.items():
117
- cache._dependencies_invoked[UUID(execution_id)] = {nodes[dep] for dep in dependencies if dep in nodes}
128
+ dependency_classes = {get_node_class(dep) for dep in dependencies}
129
+ cache._dependencies_invoked[UUID(execution_id)] = {
130
+ dep_class for dep_class in dependency_classes if dep_class is not None
131
+ }
118
132
 
119
133
  node_executions_fulfilled = raw_data.get("node_executions_fulfilled")
120
134
  if isinstance(node_executions_fulfilled, dict):
121
135
  for node, execution_ids in node_executions_fulfilled.items():
122
- node_class = nodes.get(node)
136
+ node_class = get_node_class(node)
123
137
  if not node_class:
124
138
  continue
125
139
 
@@ -130,7 +144,7 @@ class NodeExecutionCache:
130
144
  node_executions_initiated = raw_data.get("node_executions_initiated")
131
145
  if isinstance(node_executions_initiated, dict):
132
146
  for node, execution_ids in node_executions_initiated.items():
133
- node_class = nodes.get(node)
147
+ node_class = get_node_class(node)
134
148
  if not node_class:
135
149
  continue
136
150
 
@@ -141,7 +155,7 @@ class NodeExecutionCache:
141
155
  node_executions_queued = raw_data.get("node_executions_queued")
142
156
  if isinstance(node_executions_queued, dict):
143
157
  for node, execution_ids in node_executions_queued.items():
144
- node_class = nodes.get(node)
158
+ node_class = get_node_class(node)
145
159
  if not node_class:
146
160
  continue
147
161
 
@@ -192,17 +206,18 @@ class NodeExecutionCache:
192
206
  def dump(self) -> Dict[str, Any]:
193
207
  return {
194
208
  "dependencies_invoked": {
195
- str(execution_id): [str(dep) for dep in dependencies]
209
+ str(execution_id): [str(dep.__id__) for dep in dependencies]
196
210
  for execution_id, dependencies in self._dependencies_invoked.items()
197
211
  },
198
212
  "node_executions_initiated": {
199
- str(node): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
213
+ str(node.__id__): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
200
214
  },
201
215
  "node_executions_fulfilled": {
202
- str(node): execution_ids.dump() for node, execution_ids in self._node_executions_fulfilled.items()
216
+ str(node.__id__): execution_ids.dump()
217
+ for node, execution_ids in self._node_executions_fulfilled.items()
203
218
  },
204
219
  "node_executions_queued": {
205
- str(node): execution_ids for node, execution_ids in self._node_executions_queued.items()
220
+ str(node.__id__): execution_ids for node, execution_ids in self._node_executions_queued.items()
206
221
  },
207
222
  }
208
223
 
@@ -278,7 +293,7 @@ class StateMeta(UniversalBaseModel):
278
293
 
279
294
  @field_serializer("node_outputs")
280
295
  def serialize_node_outputs(self, node_outputs: Dict[OutputReference, Any], _info: Any) -> Dict[str, Any]:
281
- return {str(descriptor): value for descriptor, value in node_outputs.items()}
296
+ return {str(descriptor.id): value for descriptor, value in node_outputs.items()}
282
297
 
283
298
  @field_validator("node_outputs", mode="before")
284
299
  @classmethod
@@ -289,15 +304,22 @@ class StateMeta(UniversalBaseModel):
289
304
  return node_outputs
290
305
 
291
306
  raw_workflow_nodes = workflow_definition.get_nodes()
292
- workflow_node_outputs = {}
307
+ workflow_node_outputs: Dict[Union[str, UUID], OutputReference] = {}
293
308
  for node in raw_workflow_nodes:
294
309
  for output in node.Outputs:
295
310
  workflow_node_outputs[str(output)] = output
311
+ output_id = node.__output_ids__.get(output.name)
312
+ if output_id:
313
+ workflow_node_outputs[output_id] = output
296
314
 
297
315
  node_output_keys = list(node_outputs.keys())
298
316
  deserialized_node_outputs = {}
299
317
  for node_output_key in node_output_keys:
300
- output_reference = workflow_node_outputs.get(node_output_key)
318
+ if is_valid_uuid(node_output_key):
319
+ output_reference = workflow_node_outputs.get(UUID(node_output_key))
320
+ else:
321
+ output_reference = workflow_node_outputs.get(node_output_key)
322
+
301
323
  if not output_reference:
302
324
  continue
303
325
 
@@ -315,10 +337,11 @@ class StateMeta(UniversalBaseModel):
315
337
  if not workflow_definition:
316
338
  return node_execution_cache
317
339
 
318
- nodes_cache: Dict[str, Type["BaseNode"]] = {}
340
+ nodes_cache: Dict[Union[str, UUID], Type["BaseNode"]] = {}
319
341
  raw_workflow_nodes = workflow_definition.get_nodes()
320
342
  for node in raw_workflow_nodes:
321
343
  nodes_cache[str(node)] = node
344
+ nodes_cache[node.__id__] = node
322
345
 
323
346
  return NodeExecutionCache.deserialize(node_execution_cache, nodes_cache)
324
347
 
@@ -404,11 +427,11 @@ class BaseState(metaclass=_BaseStateMeta):
404
427
  meta: StateMeta = field(init=False)
405
428
 
406
429
  __lock__: Lock = field(init=False)
407
- __is_initializing__: bool = field(init=False)
430
+ __is_quiet__: bool = field(init=False)
408
431
  __snapshot_callback__: Callable[["BaseState"], None] = field(init=False)
409
432
 
410
433
  def __init__(self, meta: Optional[StateMeta] = None, **kwargs: Any) -> None:
411
- self.__is_initializing__ = True
434
+ self.__is_quiet__ = True
412
435
  self.__snapshot_callback__ = lambda state: None
413
436
  self.__lock__ = Lock()
414
437
 
@@ -418,14 +441,14 @@ class BaseState(metaclass=_BaseStateMeta):
418
441
  # Make all class attribute values snapshottable
419
442
  for name, value in self.__class__.__dict__.items():
420
443
  if not name.startswith("_") and name != "meta":
421
- # Bypass __is_initializing__ instead of `setattr`
444
+ # Bypass __is_quiet__ instead of `setattr`
422
445
  snapshottable_value = _make_snapshottable(value, self.__snapshot__)
423
446
  super().__setattr__(name, snapshottable_value)
424
447
 
425
448
  for name, value in kwargs.items():
426
449
  setattr(self, name, value)
427
450
 
428
- self.__is_initializing__ = False
451
+ self.__is_quiet__ = False
429
452
 
430
453
  def __deepcopy__(self, memo: Any) -> "BaseState":
431
454
  new_state = deepcopy_with_exclusions(
@@ -472,7 +495,7 @@ class BaseState(metaclass=_BaseStateMeta):
472
495
  return self.__dict__[key]
473
496
 
474
497
  def __setattr__(self, name: str, value: Any) -> None:
475
- if name.startswith("_") or self.__is_initializing__:
498
+ if name.startswith("_"):
476
499
  super().__setattr__(name, value)
477
500
  return
478
501
 
@@ -513,11 +536,33 @@ class BaseState(metaclass=_BaseStateMeta):
513
536
  Snapshots the current state to the workflow emitter. The invoked callback is overridden by the
514
537
  workflow runner.
515
538
  """
539
+ if self.__is_quiet__:
540
+ return
541
+
516
542
  try:
517
543
  self.__snapshot_callback__(deepcopy(self))
518
544
  except Exception:
519
545
  logger.exception("Failed to snapshot Workflow state.")
520
546
 
547
+ @contextmanager
548
+ def __quiet__(self):
549
+ prev = self.__is_quiet__
550
+ self.__is_quiet__ = True
551
+ try:
552
+ yield
553
+ finally:
554
+ self.__is_quiet__ = prev
555
+
556
+ @contextmanager
557
+ def __atomic__(self):
558
+ prev = self.__is_quiet__
559
+ self.__is_quiet__ = True
560
+ try:
561
+ yield
562
+ finally:
563
+ self.__is_quiet__ = prev
564
+ self.__snapshot__()
565
+
521
566
  @classmethod
522
567
  def __get_pydantic_core_schema__(
523
568
  cls, source_type: Type[Any], handler: GetCoreSchemaHandler
@@ -1,17 +1,14 @@
1
1
  import pytest
2
- from collections import defaultdict
3
2
  from copy import deepcopy
4
3
  import json
5
4
  from queue import Queue
6
- from typing import Dict
5
+ from typing import Dict, cast
7
6
 
8
7
  from vellum.workflows.nodes.bases import BaseNode
9
8
  from vellum.workflows.outputs.base import BaseOutputs
10
9
  from vellum.workflows.state.base import BaseState
11
10
  from vellum.workflows.state.encoder import DefaultStateEncoder
12
11
 
13
- snapshot_count: Dict[int, int] = defaultdict(int)
14
-
15
12
 
16
13
  @pytest.fixture()
17
14
  def mock_deepcopy(mocker):
@@ -27,9 +24,19 @@ class MockState(BaseState):
27
24
  foo: str
28
25
  nested_dict: Dict[str, int] = {}
29
26
 
30
- def __snapshot__(self) -> None:
31
- global snapshot_count
32
- snapshot_count[id(self)] += 1
27
+ __snapshot_count__: int = 0
28
+
29
+ def __init__(self, *args, **kwargs) -> None:
30
+ super().__init__(*args, **kwargs)
31
+ self.__snapshot_callback__ = lambda _: self.__mock_snapshot__()
32
+
33
+ def __mock_snapshot__(self) -> None:
34
+ self.__snapshot_count__ += 1
35
+
36
+ def __deepcopy__(self, memo: dict) -> "MockState":
37
+ new_state = cast(MockState, super().__deepcopy__(memo))
38
+ new_state.__snapshot_count__ = 0
39
+ return new_state
33
40
 
34
41
 
35
42
  class MockNode(BaseNode):
@@ -40,53 +47,56 @@ class MockNode(BaseNode):
40
47
  baz: str
41
48
 
42
49
 
50
+ MOCK_NODE_OUTPUT_ID = "e4dc3136-0c27-4bda-b3ab-ea355d5219d6"
51
+
52
+
43
53
  def test_state_snapshot__node_attribute_edit():
44
54
  # GIVEN an initial state instance
45
55
  state = MockState(foo="bar")
46
- assert snapshot_count[id(state)] == 0
56
+ assert state.__snapshot_count__ == 0
47
57
 
48
58
  # WHEN we edit an attribute
49
59
  state.foo = "baz"
50
60
 
51
61
  # THEN the snapshot is emitted
52
- assert snapshot_count[id(state)] == 1
62
+ assert state.__snapshot_count__ == 1
53
63
 
54
64
 
55
65
  def test_state_snapshot__node_output_edit():
56
66
  # GIVEN an initial state instance
57
67
  state = MockState(foo="bar")
58
- assert snapshot_count[id(state)] == 0
68
+ assert state.__snapshot_count__ == 0
59
69
 
60
70
  # WHEN we add a Node Output to state
61
71
  for output in MockNode.Outputs:
62
72
  state.meta.node_outputs[output] = "hello"
63
73
 
64
74
  # THEN the snapshot is emitted
65
- assert snapshot_count[id(state)] == 1
75
+ assert state.__snapshot_count__ == 1
66
76
 
67
77
 
68
78
  def test_state_snapshot__nested_dictionary_edit():
69
79
  # GIVEN an initial state instance
70
80
  state = MockState(foo="bar")
71
- assert snapshot_count[id(state)] == 0
81
+ assert state.__snapshot_count__ == 0
72
82
 
73
83
  # WHEN we edit a nested dictionary
74
84
  state.nested_dict["hello"] = 1
75
85
 
76
86
  # THEN the snapshot is emitted
77
- assert snapshot_count[id(state)] == 1
87
+ assert state.__snapshot_count__ == 1
78
88
 
79
89
 
80
90
  def test_state_snapshot__external_input_edit():
81
91
  # GIVEN an initial state instance
82
92
  state = MockState(foo="bar")
83
- assert snapshot_count[id(state)] == 0
93
+ assert state.__snapshot_count__ == 0
84
94
 
85
95
  # WHEN we add an external input to state
86
96
  state.meta.external_inputs[MockNode.ExternalInputs.message] = "hello"
87
97
 
88
98
  # THEN the snapshot is emitted
89
- assert snapshot_count[id(state)] == 1
99
+ assert state.__snapshot_count__ == 1
90
100
 
91
101
 
92
102
  def test_state_deepcopy():
@@ -103,7 +113,6 @@ def test_state_deepcopy():
103
113
  assert deepcopied_state.meta.node_outputs == state.meta.node_outputs
104
114
 
105
115
 
106
- @pytest.mark.skip(reason="https://app.shortcut.com/vellum/story/5654")
107
116
  def test_state_deepcopy__with_node_output_updates():
108
117
  # GIVEN an initial state instance
109
118
  state = MockState(foo="bar")
@@ -121,10 +130,10 @@ def test_state_deepcopy__with_node_output_updates():
121
130
  assert deepcopied_state.meta.node_outputs[MockNode.Outputs.baz] == "hello"
122
131
 
123
132
  # AND the original state has had the correct number of snapshots
124
- assert snapshot_count[id(state)] == 2
133
+ assert state.__snapshot_count__ == 2
125
134
 
126
135
  # AND the copied state has had the correct number of snapshots
127
- assert snapshot_count[id(deepcopied_state)] == 0
136
+ assert deepcopied_state.__snapshot_count__ == 0
128
137
 
129
138
 
130
139
  def test_state_json_serialization__with_node_output_updates():
@@ -138,7 +147,7 @@ def test_state_json_serialization__with_node_output_updates():
138
147
  json_state = json.loads(json.dumps(state, cls=DefaultStateEncoder))
139
148
 
140
149
  # THEN the state is serialized correctly
141
- assert json_state["meta"]["node_outputs"] == {"MockNode.Outputs.baz": "hello"}
150
+ assert json_state["meta"]["node_outputs"] == {MOCK_NODE_OUTPUT_ID: "hello"}
142
151
 
143
152
 
144
153
  def test_state_deepcopy__with_external_input_updates():
@@ -158,10 +167,10 @@ def test_state_deepcopy__with_external_input_updates():
158
167
  assert deepcopied_state.meta.external_inputs[MockNode.ExternalInputs.message] == "hello"
159
168
 
160
169
  # AND the original state has had the correct number of snapshots
161
- assert snapshot_count[id(state)] == 2
170
+ assert state.__snapshot_count__ == 2
162
171
 
163
172
  # AND the copied state has had the correct number of snapshots
164
- assert snapshot_count[id(deepcopied_state)] == 0
173
+ assert deepcopied_state.__snapshot_count__ == 0
165
174
 
166
175
 
167
176
  def test_state_json_serialization__with_queue():
@@ -179,7 +188,7 @@ def test_state_json_serialization__with_queue():
179
188
  json_state = json.loads(json.dumps(state, cls=DefaultStateEncoder))
180
189
 
181
190
  # THEN the state is serialized correctly with the queue turned into a list
182
- assert json_state["meta"]["node_outputs"] == {"MockNode.Outputs.baz": ["test1", "test2"]}
191
+ assert json_state["meta"]["node_outputs"] == {MOCK_NODE_OUTPUT_ID: ["test1", "test2"]}
183
192
 
184
193
 
185
194
  def test_state_snapshot__deepcopy_fails__logs_error(mock_deepcopy, mock_logger):
@@ -37,3 +37,14 @@ class Stack(Generic[_T]):
37
37
 
38
38
  def dump(self) -> List[_T]:
39
39
  return [item for item in self._items][::-1]
40
+
41
+ @classmethod
42
+ def from_list(cls, items: List[_T]) -> "Stack[_T]":
43
+ stack = cls()
44
+ stack.extend(items)
45
+ return stack
46
+
47
+ def __eq__(self, other: object) -> bool:
48
+ if not isinstance(other, Stack):
49
+ return False
50
+ return self._items == other._items
@@ -80,6 +80,11 @@ class _BaseWorkflowMeta(type):
80
80
  def __new__(mcs, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any]) -> Any:
81
81
  if "graph" not in dct:
82
82
  dct["graph"] = set()
83
+ for base in bases:
84
+ base_graph = getattr(base, "graph", None)
85
+ if base_graph:
86
+ dct["graph"] = base_graph
87
+ break
83
88
 
84
89
  if "Outputs" in dct:
85
90
  outputs_class = dct["Outputs"]
@@ -146,7 +151,7 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
146
151
 
147
152
  WorkflowEvent = Union[ # type: ignore
148
153
  GenericWorkflowEvent,
149
- WorkflowExecutionInitiatedEvent[InputsType], # type: ignore[valid-type]
154
+ WorkflowExecutionInitiatedEvent[InputsType, StateType], # type: ignore[valid-type]
150
155
  WorkflowExecutionFulfilledEvent[Outputs],
151
156
  WorkflowExecutionSnapshottedEvent[StateType], # type: ignore[valid-type]
152
157
  ]
@@ -335,7 +340,7 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
335
340
 
336
341
  if not last_event:
337
342
  return WorkflowExecutionRejectedEvent(
338
- trace_id=uuid4(),
343
+ trace_id=self._execution_context.trace_id,
339
344
  span_id=uuid4(),
340
345
  body=WorkflowExecutionRejectedBody(
341
346
  error=WorkflowError(
@@ -348,7 +353,7 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
348
353
 
349
354
  if not first_event:
350
355
  return WorkflowExecutionRejectedEvent(
351
- trace_id=uuid4(),
356
+ trace_id=self._execution_context.trace_id,
352
357
  span_id=uuid4(),
353
358
  body=WorkflowExecutionRejectedBody(
354
359
  error=WorkflowError(
@@ -367,7 +372,7 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
367
372
  return last_event
368
373
 
369
374
  return WorkflowExecutionRejectedEvent(
370
- trace_id=first_event.trace_id,
375
+ trace_id=self._execution_context.trace_id,
371
376
  span_id=first_event.span_id,
372
377
  body=WorkflowExecutionRejectedBody(
373
378
  workflow_definition=self.__class__,
@@ -482,21 +487,11 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
482
487
  return self.get_inputs_class()()
483
488
 
484
489
  def get_default_state(self, workflow_inputs: Optional[InputsType] = None) -> StateType:
485
- execution_context = self._execution_context
486
490
  return self.get_state_class()(
487
- meta=(
488
- StateMeta(
489
- parent=self._parent_state,
490
- workflow_inputs=workflow_inputs or self.get_default_inputs(),
491
- trace_id=execution_context.trace_id,
492
- workflow_definition=self.__class__,
493
- )
494
- if execution_context and int(execution_context.trace_id)
495
- else StateMeta(
496
- parent=self._parent_state,
497
- workflow_inputs=workflow_inputs or self.get_default_inputs(),
498
- workflow_definition=self.__class__,
499
- )
491
+ meta=StateMeta(
492
+ parent=self._parent_state,
493
+ workflow_inputs=workflow_inputs or self.get_default_inputs(),
494
+ workflow_definition=self.__class__,
500
495
  )
501
496
  )
502
497