vellum-ai 0.14.37__py3-none-any.whl → 0.14.39__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
Files changed (49)
  1. vellum/__init__.py +10 -0
  2. vellum/client/core/client_wrapper.py +1 -1
  3. vellum/client/reference.md +6272 -0
  4. vellum/client/types/__init__.py +10 -0
  5. vellum/client/types/ad_hoc_fulfilled_prompt_execution_meta.py +2 -0
  6. vellum/client/types/fulfilled_prompt_execution_meta.py +2 -0
  7. vellum/client/types/test_suite_run_exec_config_request.py +4 -0
  8. vellum/client/types/test_suite_run_progress.py +20 -0
  9. vellum/client/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +27 -0
  10. vellum/client/types/test_suite_run_prompt_sandbox_exec_config_request.py +29 -0
  11. vellum/client/types/test_suite_run_read.py +3 -0
  12. vellum/client/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +22 -0
  13. vellum/client/types/test_suite_run_workflow_sandbox_exec_config_request.py +29 -0
  14. vellum/client/types/vellum_sdk_error_code_enum.py +1 -0
  15. vellum/client/types/workflow_execution_event_error_code.py +1 -0
  16. vellum/plugins/pydantic.py +1 -1
  17. vellum/types/test_suite_run_progress.py +3 -0
  18. vellum/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +3 -0
  19. vellum/types/test_suite_run_prompt_sandbox_exec_config_request.py +3 -0
  20. vellum/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +3 -0
  21. vellum/types/test_suite_run_workflow_sandbox_exec_config_request.py +3 -0
  22. vellum/workflows/errors/types.py +1 -0
  23. vellum/workflows/events/node.py +2 -1
  24. vellum/workflows/events/tests/test_event.py +1 -0
  25. vellum/workflows/events/types.py +3 -40
  26. vellum/workflows/events/workflow.py +15 -4
  27. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +7 -1
  28. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +94 -3
  29. vellum/workflows/nodes/displayable/conftest.py +2 -6
  30. vellum/workflows/nodes/displayable/guardrail_node/node.py +1 -1
  31. vellum/workflows/nodes/displayable/guardrail_node/tests/__init__.py +0 -0
  32. vellum/workflows/nodes/displayable/guardrail_node/tests/test_node.py +50 -0
  33. vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +6 -1
  34. vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +323 -0
  35. vellum/workflows/runner/runner.py +78 -57
  36. vellum/workflows/state/base.py +177 -50
  37. vellum/workflows/state/tests/test_state.py +26 -20
  38. vellum/workflows/types/definition.py +71 -0
  39. vellum/workflows/types/generics.py +34 -1
  40. vellum/workflows/workflows/base.py +26 -19
  41. vellum/workflows/workflows/tests/test_base_workflow.py +232 -1
  42. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.39.dist-info}/METADATA +1 -1
  43. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.39.dist-info}/RECORD +49 -35
  44. vellum_cli/push.py +2 -3
  45. vellum_cli/tests/test_push.py +52 -0
  46. vellum_ee/workflows/display/vellum.py +0 -5
  47. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.39.dist-info}/LICENSE +0 -0
  48. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.39.dist-info}/WHEEL +0 -0
  49. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.39.dist-info}/entry_points.txt +0 -0
--- a/vellum/workflows/state/base.py
+++ b/vellum/workflows/state/base.py
@@ -1,4 +1,5 @@
 from collections import defaultdict
+from contextlib import contextmanager
 from copy import deepcopy
 from dataclasses import field
 from datetime import datetime
@@ -6,7 +7,7 @@ import logging
 from queue import Queue
 from threading import Lock
 from uuid import UUID, uuid4
-from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, cast
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, cast
 from typing_extensions import dataclass_transform

 from pydantic import GetCoreSchemaHandler, ValidationInfo, field_serializer, field_validator
@@ -16,19 +17,14 @@ from vellum.core.pydantic_utilities import UniversalBaseModel
 from vellum.workflows.constants import undefined
 from vellum.workflows.edges.edge import Edge
 from vellum.workflows.inputs.base import BaseInputs
-from vellum.workflows.outputs.base import BaseOutputs
 from vellum.workflows.references import ExternalInputReference, OutputReference, StateValueReference
-from vellum.workflows.types.generics import StateType
+from vellum.workflows.types.definition import CodeResourceDefinition, serialize_type_encoder_with_id
+from vellum.workflows.types.generics import StateType, import_workflow_class, is_workflow_class
 from vellum.workflows.types.stack import Stack
-from vellum.workflows.types.utils import (
-    datetime_now,
-    deepcopy_with_exclusions,
-    get_class_attr_names,
-    get_class_by_qualname,
-    infer_types,
-)
+from vellum.workflows.types.utils import datetime_now, deepcopy_with_exclusions, get_class_attr_names, infer_types

 if TYPE_CHECKING:
+    from vellum.workflows import BaseWorkflow
     from vellum.workflows.nodes.bases import BaseNode

 logger = logging.getLogger(__name__)
@@ -94,38 +90,65 @@ def _make_snapshottable(value: Any, snapshot_callback: Callable[[], None]) -> Any:
     return value


+NodeExecutionsFulfilled = Dict[Type["BaseNode"], Stack[UUID]]
+NodeExecutionsInitiated = Dict[Type["BaseNode"], Set[UUID]]
+NodeExecutionsQueued = Dict[Type["BaseNode"], List[UUID]]
+DependenciesInvoked = Dict[UUID, Set[Type["BaseNode"]]]
+
+
 class NodeExecutionCache:
-    _node_executions_fulfilled: Dict[Type["BaseNode"], Stack[UUID]]
-    _node_executions_initiated: Dict[Type["BaseNode"], Set[UUID]]
-    _node_executions_queued: Dict[Type["BaseNode"], List[UUID]]
-    _dependencies_invoked: Dict[UUID, Set[Type["BaseNode"]]]
+    _node_executions_fulfilled: NodeExecutionsFulfilled
+    _node_executions_initiated: NodeExecutionsInitiated
+    _node_executions_queued: NodeExecutionsQueued
+    _dependencies_invoked: DependenciesInvoked

-    def __init__(
-        self,
-        dependencies_invoked: Optional[Dict[str, Sequence[str]]] = None,
-        node_executions_fulfilled: Optional[Dict[str, Sequence[str]]] = None,
-        node_executions_initiated: Optional[Dict[str, Sequence[str]]] = None,
-        node_executions_queued: Optional[Dict[str, Sequence[str]]] = None,
-    ) -> None:
+    def __init__(self) -> None:
         self._dependencies_invoked = defaultdict(set)
         self._node_executions_fulfilled = defaultdict(Stack[UUID])
         self._node_executions_initiated = defaultdict(set)
         self._node_executions_queued = defaultdict(list)

-        for execution_id, dependencies in (dependencies_invoked or {}).items():
-            self._dependencies_invoked[UUID(execution_id)] = {get_class_by_qualname(dep) for dep in dependencies}
+    @classmethod
+    def deserialize(cls, raw_data: dict, nodes: Dict[str, Type["BaseNode"]]):
+        cache = cls()
+
+        dependencies_invoked = raw_data.get("dependencies_invoked")
+        if isinstance(dependencies_invoked, dict):
+            for execution_id, dependencies in dependencies_invoked.items():
+                cache._dependencies_invoked[UUID(execution_id)] = {nodes[dep] for dep in dependencies if dep in nodes}
+
+        node_executions_fulfilled = raw_data.get("node_executions_fulfilled")
+        if isinstance(node_executions_fulfilled, dict):
+            for node, execution_ids in node_executions_fulfilled.items():
+                node_class = nodes.get(node)
+                if not node_class:
+                    continue

-        for node, execution_ids in (node_executions_fulfilled or {}).items():
-            node_class = get_class_by_qualname(node)
-            self._node_executions_fulfilled[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
+                cache._node_executions_fulfilled[node_class].extend(
+                    UUID(execution_id) for execution_id in execution_ids
+                )
+
+        node_executions_initiated = raw_data.get("node_executions_initiated")
+        if isinstance(node_executions_initiated, dict):
+            for node, execution_ids in node_executions_initiated.items():
+                node_class = nodes.get(node)
+                if not node_class:
+                    continue
+
+                cache._node_executions_initiated[node_class].update(
+                    {UUID(execution_id) for execution_id in execution_ids}
+                )
+
+        node_executions_queued = raw_data.get("node_executions_queued")
+        if isinstance(node_executions_queued, dict):
+            for node, execution_ids in node_executions_queued.items():
+                node_class = nodes.get(node)
+                if not node_class:
+                    continue

-        for node, execution_ids in (node_executions_initiated or {}).items():
-            node_class = get_class_by_qualname(node)
-            self._node_executions_initiated[node_class].update({UUID(execution_id) for execution_id in execution_ids})
+                cache._node_executions_queued[node_class].extend(UUID(execution_id) for execution_id in execution_ids)

-        for node, execution_ids in (node_executions_queued or {}).items():
-            node_class = get_class_by_qualname(node)
-            self._node_executions_queued[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
+        return cache

     def _invoke_dependency(
         self,
@@ -207,8 +230,8 @@ def default_datetime_factory() -> datetime:


 class StateMeta(UniversalBaseModel):
+    workflow_definition: Type["BaseWorkflow"] = field(default_factory=import_workflow_class)
     id: UUID = field(default_factory=uuid4_default_factory)
-    trace_id: UUID = field(default_factory=uuid4_default_factory)
     span_id: UUID = field(default_factory=uuid4_default_factory)
     updated_ts: datetime = field(default_factory=default_datetime_factory)
     workflow_inputs: BaseInputs = field(default_factory=BaseInputs)
@@ -219,8 +242,6 @@ class StateMeta(UniversalBaseModel):
     __snapshot_callback__: Optional[Callable[[], None]] = field(init=False, default=None)

     def model_post_init(self, context: Any) -> None:
-        if self.parent:
-            self.trace_id = self.parent.meta.trace_id
         self.__snapshot_callback__ = None

     def add_snapshot_callback(self, callback: Callable[[], None]) -> None:
@@ -237,6 +258,25 @@ class StateMeta(UniversalBaseModel):
         if callable(self.__snapshot_callback__):
             self.__snapshot_callback__()

+    @field_serializer("workflow_definition")
+    def serialize_workflow_definition(self, workflow_definition: Type["BaseWorkflow"], _info: Any) -> Dict[str, Any]:
+        return serialize_type_encoder_with_id(workflow_definition)
+
+    @field_validator("workflow_definition", mode="before")
+    @classmethod
+    def deserialize_workflow_definition(cls, workflow_definition: Any, info: ValidationInfo):
+        if isinstance(workflow_definition, dict):
+            deserialized_workflow_definition = CodeResourceDefinition.model_validate(workflow_definition).decode()
+            if not is_workflow_class(deserialized_workflow_definition):
+                return import_workflow_class()
+
+            return deserialized_workflow_definition
+
+        if is_workflow_class(workflow_definition):
+            return workflow_definition
+
+        return import_workflow_class()
+
     @field_serializer("node_outputs")
     def serialize_node_outputs(self, node_outputs: Dict[OutputReference, Any], _info: Any) -> Dict[str, Any]:
         return {str(descriptor): value for descriptor, value in node_outputs.items()}
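
Note on the new workflow_definition field above: StateMeta now round-trips the owning workflow class through plain JSON, serializing it as an id/name/module record via serialize_type_encoder_with_id and resolving it back through CodeResourceDefinition.decode(). When the record does not resolve to a workflow class, the validator falls back to the base class rather than failing. A minimal sketch of that fallback path (assuming the remaining StateMeta fields can take their defaults):

from vellum.workflows import BaseWorkflow
from vellum.workflows.state.base import StateMeta

# A record whose module path cannot be imported decodes to None, so the
# before-validator falls back to the base workflow class instead of raising.
meta = StateMeta.model_validate(
    {
        "workflow_definition": {
            "id": "00000000-0000-0000-0000-000000000000",
            "name": "DoesNotExist",
            "module": ["no", "such", "module"],
        }
    }
)
assert meta.workflow_definition is BaseWorkflow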
@@ -244,17 +284,16 @@ class StateMeta(UniversalBaseModel):
     @field_validator("node_outputs", mode="before")
     @classmethod
     def deserialize_node_outputs(cls, node_outputs: Any, info: ValidationInfo):
-        if isinstance(node_outputs, dict) and isinstance(info.context, dict):
-            raw_workflow_nodes = info.context.get("nodes")
-            workflow_node_outputs = {}
-            if isinstance(raw_workflow_nodes, list):
-                for node in raw_workflow_nodes:
-                    Outputs = getattr(node, "Outputs", None)
-                    if not isinstance(Outputs, type) or not issubclass(Outputs, BaseOutputs):
-                        continue
+        if isinstance(node_outputs, dict):
+            workflow_definition = cls._get_workflow(info)
+            if not workflow_definition:
+                return node_outputs

-                    for output in Outputs:
-                        workflow_node_outputs[str(output)] = output
+            raw_workflow_nodes = workflow_definition.get_nodes()
+            workflow_node_outputs = {}
+            for node in raw_workflow_nodes:
+                for output in node.Outputs:
+                    workflow_node_outputs[str(output)] = output

             node_output_keys = list(node_outputs.keys())
             deserialized_node_outputs = {}
@@ -269,12 +308,64 @@ class StateMeta(UniversalBaseModel):

         return node_outputs

+    @field_validator("node_execution_cache", mode="before")
+    @classmethod
+    def deserialize_node_execution_cache(cls, node_execution_cache: Any, info: ValidationInfo):
+        if isinstance(node_execution_cache, dict):
+            workflow_definition = cls._get_workflow(info)
+            if not workflow_definition:
+                return node_execution_cache
+
+            nodes_cache: Dict[str, Type["BaseNode"]] = {}
+            raw_workflow_nodes = workflow_definition.get_nodes()
+            for node in raw_workflow_nodes:
+                nodes_cache[str(node)] = node
+
+            return NodeExecutionCache.deserialize(node_execution_cache, nodes_cache)
+
+        return node_execution_cache
+
+    @field_validator("workflow_inputs", mode="before")
+    @classmethod
+    def deserialize_workflow_inputs(cls, workflow_inputs: Any, info: ValidationInfo):
+        workflow_definition = cls._get_workflow(info)
+
+        if workflow_definition:
+            if workflow_inputs is None:
+                return workflow_definition.get_inputs_class()()
+            if isinstance(workflow_inputs, dict):
+                return workflow_definition.get_inputs_class()(**workflow_inputs)
+
+        return workflow_inputs
+
     @field_serializer("external_inputs")
     def serialize_external_inputs(
         self, external_inputs: Dict[ExternalInputReference, Any], _info: Any
     ) -> Dict[str, Any]:
         return {str(descriptor): value for descriptor, value in external_inputs.items()}

+    @field_validator("parent", mode="before")
+    @classmethod
+    def deserialize_parent(cls, parent: Any, info: ValidationInfo):
+        if isinstance(parent, dict):
+            workflow_definition = cls._get_workflow(info)
+            if not workflow_definition:
+                return parent
+
+            parent_meta = parent.get("meta")
+            if not isinstance(parent_meta, dict):
+                return parent
+
+            parent_workflow_definition = cls.deserialize_workflow_definition(
+                parent_meta.get("workflow_definition"), info
+            )
+            if not is_workflow_class(parent_workflow_definition):
+                return parent
+
+            return parent_workflow_definition.deserialize_state(parent)
+
+        return parent
+
     def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "StateMeta":
         if not memo:
             memo = {}
@@ -295,16 +386,30 @@ class StateMeta(UniversalBaseModel):

         return super().__deepcopy__(memo)

+    @classmethod
+    def _get_workflow(cls, info: ValidationInfo) -> Optional[Type["BaseWorkflow"]]:
+        if not info.context:
+            return None
+
+        if not isinstance(info.context, dict):
+            return None
+
+        workflow_definition = info.context.get("workflow_definition")
+        if not is_workflow_class(workflow_definition):
+            return None
+
+        return workflow_definition
+

 class BaseState(metaclass=_BaseStateMeta):
     meta: StateMeta = field(init=False)

     __lock__: Lock = field(init=False)
-    __is_initializing__: bool = field(init=False)
+    __is_quiet__: bool = field(init=False)
     __snapshot_callback__: Callable[["BaseState"], None] = field(init=False)

     def __init__(self, meta: Optional[StateMeta] = None, **kwargs: Any) -> None:
-        self.__is_initializing__ = True
+        self.__is_quiet__ = True
         self.__snapshot_callback__ = lambda state: None
         self.__lock__ = Lock()

@@ -314,14 +419,14 @@ class BaseState(metaclass=_BaseStateMeta):
         # Make all class attribute values snapshottable
         for name, value in self.__class__.__dict__.items():
             if not name.startswith("_") and name != "meta":
-                # Bypass __is_initializing__ instead of `setattr`
+                # Bypass __is_quiet__ instead of `setattr`
                 snapshottable_value = _make_snapshottable(value, self.__snapshot__)
                 super().__setattr__(name, snapshottable_value)

         for name, value in kwargs.items():
             setattr(self, name, value)

-        self.__is_initializing__ = False
+        self.__is_quiet__ = False

     def __deepcopy__(self, memo: Any) -> "BaseState":
         new_state = deepcopy_with_exclusions(
@@ -368,7 +473,7 @@ class BaseState(metaclass=_BaseStateMeta):
         return self.__dict__[key]

     def __setattr__(self, name: str, value: Any) -> None:
-        if name.startswith("_") or self.__is_initializing__:
+        if name.startswith("_"):
             super().__setattr__(name, value)
             return

@@ -409,11 +514,33 @@ class BaseState(metaclass=_BaseStateMeta):
         Snapshots the current state to the workflow emitter. The invoked callback is overridden by the
         workflow runner.
         """
+        if self.__is_quiet__:
+            return
+
         try:
             self.__snapshot_callback__(deepcopy(self))
         except Exception:
             logger.exception("Failed to snapshot Workflow state.")

+    @contextmanager
+    def __quiet__(self):
+        prev = self.__is_quiet__
+        self.__is_quiet__ = True
+        try:
+            yield
+        finally:
+            self.__is_quiet__ = prev
+
+    @contextmanager
+    def __atomic__(self):
+        prev = self.__is_quiet__
+        self.__is_quiet__ = True
+        try:
+            yield
+        finally:
+            self.__is_quiet__ = prev
+            self.__snapshot__()
+
     @classmethod
     def __get_pydantic_core_schema__(
         cls, source_type: Type[Any], handler: GetCoreSchemaHandler
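
Note on the snapshotting changes above: the old __is_initializing__ flag becomes a general __is_quiet__ mute switch, checked at the top of __snapshot__. The __quiet__ context manager suppresses snapshots for the duration of a block, while __atomic__ suppresses them and emits a single snapshot on exit, presumably so callers can batch many state edits into one emitted snapshot. A standalone sketch of the same gating pattern (the SnapshottingState class is illustrative, not the SDK's BaseState):

from contextlib import contextmanager

class SnapshottingState:
    # Illustrative stand-in for BaseState's snapshot gating.

    def __init__(self) -> None:
        self._is_quiet = False
        self.snapshots = 0  # stands in for the workflow emitter callback

    def snapshot(self) -> None:
        if self._is_quiet:
            return  # mirrors the early return added to BaseState.__snapshot__
        self.snapshots += 1

    @contextmanager
    def quiet(self):
        prev = self._is_quiet
        self._is_quiet = True
        try:
            yield
        finally:
            self._is_quiet = prev  # restore, so nested blocks compose

    @contextmanager
    def atomic(self):
        prev = self._is_quiet
        self._is_quiet = True
        try:
            yield
        finally:
            self._is_quiet = prev
            self.snapshot()  # one snapshot for the whole batch of edits

state = SnapshottingState()
with state.atomic():
    state.snapshot()  # muted inside the block
    state.snapshot()  # muted inside the block
assert state.snapshots == 1  # only the closing snapshot fired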
--- a/vellum/workflows/state/tests/test_state.py
+++ b/vellum/workflows/state/tests/test_state.py
@@ -1,17 +1,14 @@
 import pytest
-from collections import defaultdict
 from copy import deepcopy
 import json
 from queue import Queue
-from typing import Dict
+from typing import Dict, cast

 from vellum.workflows.nodes.bases import BaseNode
 from vellum.workflows.outputs.base import BaseOutputs
 from vellum.workflows.state.base import BaseState
 from vellum.workflows.state.encoder import DefaultStateEncoder

-snapshot_count: Dict[int, int] = defaultdict(int)
-

 @pytest.fixture()
 def mock_deepcopy(mocker):
@@ -27,9 +24,19 @@ class MockState(BaseState):
     foo: str
     nested_dict: Dict[str, int] = {}

-    def __snapshot__(self) -> None:
-        global snapshot_count
-        snapshot_count[id(self)] += 1
+    __snapshot_count__: int = 0
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.__snapshot_callback__ = lambda _: self.__mock_snapshot__()
+
+    def __mock_snapshot__(self) -> None:
+        self.__snapshot_count__ += 1
+
+    def __deepcopy__(self, memo: dict) -> "MockState":
+        new_state = cast(MockState, super().__deepcopy__(memo))
+        new_state.__snapshot_count__ = 0
+        return new_state


 class MockNode(BaseNode):
@@ -43,50 +50,50 @@ class MockNode(BaseNode):
 def test_state_snapshot__node_attribute_edit():
     # GIVEN an initial state instance
     state = MockState(foo="bar")
-    assert snapshot_count[id(state)] == 0
+    assert state.__snapshot_count__ == 0

     # WHEN we edit an attribute
     state.foo = "baz"

     # THEN the snapshot is emitted
-    assert snapshot_count[id(state)] == 1
+    assert state.__snapshot_count__ == 1


 def test_state_snapshot__node_output_edit():
     # GIVEN an initial state instance
     state = MockState(foo="bar")
-    assert snapshot_count[id(state)] == 0
+    assert state.__snapshot_count__ == 0

     # WHEN we add a Node Output to state
     for output in MockNode.Outputs:
         state.meta.node_outputs[output] = "hello"

     # THEN the snapshot is emitted
-    assert snapshot_count[id(state)] == 1
+    assert state.__snapshot_count__ == 1


 def test_state_snapshot__nested_dictionary_edit():
     # GIVEN an initial state instance
     state = MockState(foo="bar")
-    assert snapshot_count[id(state)] == 0
+    assert state.__snapshot_count__ == 0

     # WHEN we edit a nested dictionary
     state.nested_dict["hello"] = 1

     # THEN the snapshot is emitted
-    assert snapshot_count[id(state)] == 1
+    assert state.__snapshot_count__ == 1


 def test_state_snapshot__external_input_edit():
     # GIVEN an initial state instance
     state = MockState(foo="bar")
-    assert snapshot_count[id(state)] == 0
+    assert state.__snapshot_count__ == 0

     # WHEN we add an external input to state
     state.meta.external_inputs[MockNode.ExternalInputs.message] = "hello"

     # THEN the snapshot is emitted
-    assert snapshot_count[id(state)] == 1
+    assert state.__snapshot_count__ == 1


 def test_state_deepcopy():
@@ -103,7 +110,6 @@ def test_state_deepcopy():
     assert deepcopied_state.meta.node_outputs == state.meta.node_outputs


-@pytest.mark.skip(reason="https://app.shortcut.com/vellum/story/5654")
 def test_state_deepcopy__with_node_output_updates():
     # GIVEN an initial state instance
     state = MockState(foo="bar")
@@ -121,10 +127,10 @@ def test_state_deepcopy__with_node_output_updates():
     assert deepcopied_state.meta.node_outputs[MockNode.Outputs.baz] == "hello"

     # AND the original state has had the correct number of snapshots
-    assert snapshot_count[id(state)] == 2
+    assert state.__snapshot_count__ == 2

     # AND the copied state has had the correct number of snapshots
-    assert snapshot_count[id(deepcopied_state)] == 0
+    assert deepcopied_state.__snapshot_count__ == 0


 def test_state_json_serialization__with_node_output_updates():
@@ -158,10 +164,10 @@ def test_state_deepcopy__with_external_input_updates():
     assert deepcopied_state.meta.external_inputs[MockNode.ExternalInputs.message] == "hello"

     # AND the original state has had the correct number of snapshots
-    assert snapshot_count[id(state)] == 2
+    assert state.__snapshot_count__ == 2

     # AND the copied state has had the correct number of snapshots
-    assert snapshot_count[id(deepcopied_state)] == 0
+    assert deepcopied_state.__snapshot_count__ == 0


 def test_state_json_serialization__with_queue():
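
The tests above replace the module-level snapshot_count global with a per-instance __snapshot_count__ wired through __snapshot_callback__, which receives a deep copy of the state on every tracked mutation; MockState.__deepcopy__ resets the counter on the copy so the two instances count independently. The same hook is usable directly. A short sketch (CountingState and events are illustrative, and this assumes a BaseState subclass can be constructed directly, as the tests do):

from vellum.workflows.state.base import BaseState

class CountingState(BaseState):
    foo: str = ""

events = []
state = CountingState()
# Names with a leading underscore bypass snapshotting in __setattr__,
# so installing the callback itself does not emit a snapshot.
state.__snapshot_callback__ = lambda snapshot: events.append(snapshot.foo)

state.foo = "bar"  # each attribute edit snapshots a deep copy of the state
state.foo = "baz"
assert events == ["bar", "baz"]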
--- /dev/null
+++ b/vellum/workflows/types/definition.py
@@ -0,0 +1,71 @@
+import importlib
+import inspect
+from types import FrameType
+from uuid import UUID
+from typing import Annotated, Any, Dict, Optional, Union
+
+from pydantic import BeforeValidator
+
+from vellum.client.types.code_resource_definition import CodeResourceDefinition as ClientCodeResourceDefinition
+
+
+def serialize_type_encoder(obj: type) -> Dict[str, Any]:
+    return {
+        "name": obj.__name__,
+        "module": obj.__module__.split("."),
+    }
+
+
+def serialize_type_encoder_with_id(obj: Union[type, "CodeResourceDefinition"]) -> Dict[str, Any]:
+    if hasattr(obj, "__id__") and isinstance(obj, type):
+        return {
+            "id": getattr(obj, "__id__"),
+            **serialize_type_encoder(obj),
+        }
+    elif isinstance(obj, CodeResourceDefinition):
+        return obj.model_dump(mode="json")
+
+    raise AttributeError(f"The object of type '{type(obj).__name__}' must have an '__id__' attribute.")
+
+
+class CodeResourceDefinition(ClientCodeResourceDefinition):
+    id: UUID
+
+    @staticmethod
+    def encode(obj: type) -> "CodeResourceDefinition":
+        return CodeResourceDefinition(**serialize_type_encoder_with_id(obj))
+
+    def decode(self) -> Any:
+        if ".<locals>." in self.name:
+            # We are decoding a local class that should already be loaded in our stack frame. So
+            # we climb up to look for it.
+            frame = inspect.currentframe()
+            return self._resolve_local(frame)
+
+        try:
+            imported_module = importlib.import_module(".".join(self.module))
+        except ImportError:
+            return None
+
+        return getattr(imported_module, self.name, None)
+
+    def _resolve_local(self, frame: Optional[FrameType]) -> Any:
+        if not frame:
+            return None
+
+        frame_module = frame.f_globals.get("__name__")
+        if not isinstance(frame_module, str) or frame_module.split(".") != self.module:
+            return self._resolve_local(frame.f_back)
+
+        outer, inner = self.name.split(".<locals>.")
+        frame_outer = frame.f_code.co_name
+        if frame_outer != outer:
+            return self._resolve_local(frame.f_back)
+
+        return frame.f_locals.get(inner)
+
+
+VellumCodeResourceDefinition = Annotated[
+    CodeResourceDefinition,
+    BeforeValidator(lambda d: (d if type(d) is dict else serialize_type_encoder_with_id(d))),
+]
--- a/vellum/workflows/types/generics.py
+++ b/vellum/workflows/types/generics.py
@@ -1,4 +1,6 @@
-from typing import TYPE_CHECKING, TypeVar
+from functools import cache
+from typing import TYPE_CHECKING, Any, Type, TypeVar
+from typing_extensions import TypeGuard

 if TYPE_CHECKING:
     from vellum.workflows import BaseWorkflow
@@ -12,3 +14,34 @@ StateType = TypeVar("StateType", bound="BaseState")
 WorkflowType = TypeVar("WorkflowType", bound="BaseWorkflow")
 InputsType = TypeVar("InputsType", bound="BaseInputs")
 OutputsType = TypeVar("OutputsType", bound="BaseOutputs")
+
+
+@cache
+def _import_node_class() -> Type["BaseNode"]:
+    """
+    Helper function to help avoid circular imports.
+    """
+
+    from vellum.workflows.nodes import BaseNode
+
+    return BaseNode
+
+
+def import_workflow_class() -> Type["BaseWorkflow"]:
+    """
+    Helper function to help avoid circular imports.
+    """
+
+    from vellum.workflows.workflows import BaseWorkflow
+
+    return BaseWorkflow
+
+
+def is_node_class(obj: Any) -> TypeGuard[Type["BaseNode"]]:
+    base_node_class = _import_node_class()
+    return isinstance(obj, type) and issubclass(obj, base_node_class)
+
+
+def is_workflow_class(obj: Any) -> TypeGuard[Type["BaseWorkflow"]]:
+    base_workflow_class = import_workflow_class()
+    return isinstance(obj, type) and issubclass(obj, base_workflow_class)
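
is_node_class and is_workflow_class return TypeGuards, so a passing check both verifies the object at runtime and narrows its static type in the same branch. A brief usage sketch (the describe helper is illustrative):

from typing import Any

from vellum.workflows import BaseWorkflow
from vellum.workflows.types.generics import is_workflow_class

def describe(obj: Any) -> str:
    if is_workflow_class(obj):
        # TypeGuard narrows obj to Type[BaseWorkflow] here, so type
        # checkers accept class-level access without a cast
        return f"workflow: {obj.__name__}"
    return "not a workflow"

assert describe(BaseWorkflow) == "workflow: BaseWorkflow"
assert describe(object) == "not a workflow"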