vellum-ai 0.14.37__py3-none-any.whl → 0.14.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. vellum/__init__.py +8 -0
  2. vellum/client/core/client_wrapper.py +1 -1
  3. vellum/client/reference.md +6272 -0
  4. vellum/client/types/__init__.py +8 -0
  5. vellum/client/types/ad_hoc_fulfilled_prompt_execution_meta.py +2 -0
  6. vellum/client/types/fulfilled_prompt_execution_meta.py +2 -0
  7. vellum/client/types/test_suite_run_exec_config_request.py +4 -0
  8. vellum/client/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +27 -0
  9. vellum/client/types/test_suite_run_prompt_sandbox_exec_config_request.py +29 -0
  10. vellum/client/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +22 -0
  11. vellum/client/types/test_suite_run_workflow_sandbox_exec_config_request.py +29 -0
  12. vellum/plugins/pydantic.py +1 -1
  13. vellum/types/test_suite_run_prompt_sandbox_exec_config_data_request.py +3 -0
  14. vellum/types/test_suite_run_prompt_sandbox_exec_config_request.py +3 -0
  15. vellum/types/test_suite_run_workflow_sandbox_exec_config_data_request.py +3 -0
  16. vellum/types/test_suite_run_workflow_sandbox_exec_config_request.py +3 -0
  17. vellum/workflows/events/node.py +2 -1
  18. vellum/workflows/events/types.py +3 -40
  19. vellum/workflows/events/workflow.py +2 -1
  20. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +94 -3
  21. vellum/workflows/nodes/displayable/conftest.py +2 -6
  22. vellum/workflows/nodes/displayable/guardrail_node/node.py +1 -1
  23. vellum/workflows/nodes/displayable/guardrail_node/tests/__init__.py +0 -0
  24. vellum/workflows/nodes/displayable/guardrail_node/tests/test_node.py +50 -0
  25. vellum/workflows/nodes/displayable/prompt_deployment_node/tests/test_node.py +297 -0
  26. vellum/workflows/runner/runner.py +44 -43
  27. vellum/workflows/state/base.py +149 -45
  28. vellum/workflows/types/definition.py +71 -0
  29. vellum/workflows/types/generics.py +34 -1
  30. vellum/workflows/workflows/base.py +20 -3
  31. vellum/workflows/workflows/tests/test_base_workflow.py +232 -1
  32. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/METADATA +1 -1
  33. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/RECORD +37 -25
  34. vellum_ee/workflows/display/vellum.py +0 -5
  35. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/LICENSE +0 -0
  36. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/WHEEL +0 -0
  37. {vellum_ai-0.14.37.dist-info → vellum_ai-0.14.38.dist-info}/entry_points.txt +0 -0
@@ -6,7 +6,7 @@ import logging
6
6
  from queue import Queue
7
7
  from threading import Lock
8
8
  from uuid import UUID, uuid4
9
- from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, cast
9
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, cast
10
10
  from typing_extensions import dataclass_transform
11
11
 
12
12
  from pydantic import GetCoreSchemaHandler, ValidationInfo, field_serializer, field_validator
@@ -16,19 +16,14 @@ from vellum.core.pydantic_utilities import UniversalBaseModel
16
16
  from vellum.workflows.constants import undefined
17
17
  from vellum.workflows.edges.edge import Edge
18
18
  from vellum.workflows.inputs.base import BaseInputs
19
- from vellum.workflows.outputs.base import BaseOutputs
20
19
  from vellum.workflows.references import ExternalInputReference, OutputReference, StateValueReference
21
- from vellum.workflows.types.generics import StateType
20
+ from vellum.workflows.types.definition import CodeResourceDefinition, serialize_type_encoder_with_id
21
+ from vellum.workflows.types.generics import StateType, import_workflow_class, is_workflow_class
22
22
  from vellum.workflows.types.stack import Stack
23
- from vellum.workflows.types.utils import (
24
- datetime_now,
25
- deepcopy_with_exclusions,
26
- get_class_attr_names,
27
- get_class_by_qualname,
28
- infer_types,
29
- )
23
+ from vellum.workflows.types.utils import datetime_now, deepcopy_with_exclusions, get_class_attr_names, infer_types
30
24
 
31
25
  if TYPE_CHECKING:
26
+ from vellum.workflows import BaseWorkflow
32
27
  from vellum.workflows.nodes.bases import BaseNode
33
28
 
34
29
  logger = logging.getLogger(__name__)
@@ -94,38 +89,65 @@ def _make_snapshottable(value: Any, snapshot_callback: Callable[[], None]) -> An
94
89
  return value
95
90
 
96
91
 
92
+ NodeExecutionsFulfilled = Dict[Type["BaseNode"], Stack[UUID]]
93
+ NodeExecutionsInitiated = Dict[Type["BaseNode"], Set[UUID]]
94
+ NodeExecutionsQueued = Dict[Type["BaseNode"], List[UUID]]
95
+ DependenciesInvoked = Dict[UUID, Set[Type["BaseNode"]]]
96
+
97
+
97
98
  class NodeExecutionCache:
98
- _node_executions_fulfilled: Dict[Type["BaseNode"], Stack[UUID]]
99
- _node_executions_initiated: Dict[Type["BaseNode"], Set[UUID]]
100
- _node_executions_queued: Dict[Type["BaseNode"], List[UUID]]
101
- _dependencies_invoked: Dict[UUID, Set[Type["BaseNode"]]]
99
+ _node_executions_fulfilled: NodeExecutionsFulfilled
100
+ _node_executions_initiated: NodeExecutionsInitiated
101
+ _node_executions_queued: NodeExecutionsQueued
102
+ _dependencies_invoked: DependenciesInvoked
102
103
 
103
- def __init__(
104
- self,
105
- dependencies_invoked: Optional[Dict[str, Sequence[str]]] = None,
106
- node_executions_fulfilled: Optional[Dict[str, Sequence[str]]] = None,
107
- node_executions_initiated: Optional[Dict[str, Sequence[str]]] = None,
108
- node_executions_queued: Optional[Dict[str, Sequence[str]]] = None,
109
- ) -> None:
104
+ def __init__(self) -> None:
110
105
  self._dependencies_invoked = defaultdict(set)
111
106
  self._node_executions_fulfilled = defaultdict(Stack[UUID])
112
107
  self._node_executions_initiated = defaultdict(set)
113
108
  self._node_executions_queued = defaultdict(list)
114
109
 
115
- for execution_id, dependencies in (dependencies_invoked or {}).items():
116
- self._dependencies_invoked[UUID(execution_id)] = {get_class_by_qualname(dep) for dep in dependencies}
110
+ @classmethod
111
+ def deserialize(cls, raw_data: dict, nodes: Dict[str, Type["BaseNode"]]):
112
+ cache = cls()
113
+
114
+ dependencies_invoked = raw_data.get("dependencies_invoked")
115
+ if isinstance(dependencies_invoked, dict):
116
+ for execution_id, dependencies in dependencies_invoked.items():
117
+ cache._dependencies_invoked[UUID(execution_id)] = {nodes[dep] for dep in dependencies if dep in nodes}
118
+
119
+ node_executions_fulfilled = raw_data.get("node_executions_fulfilled")
120
+ if isinstance(node_executions_fulfilled, dict):
121
+ for node, execution_ids in node_executions_fulfilled.items():
122
+ node_class = nodes.get(node)
123
+ if not node_class:
124
+ continue
125
+
126
+ cache._node_executions_fulfilled[node_class].extend(
127
+ UUID(execution_id) for execution_id in execution_ids
128
+ )
129
+
130
+ node_executions_initiated = raw_data.get("node_executions_initiated")
131
+ if isinstance(node_executions_initiated, dict):
132
+ for node, execution_ids in node_executions_initiated.items():
133
+ node_class = nodes.get(node)
134
+ if not node_class:
135
+ continue
136
+
137
+ cache._node_executions_initiated[node_class].update(
138
+ {UUID(execution_id) for execution_id in execution_ids}
139
+ )
117
140
 
118
- for node, execution_ids in (node_executions_fulfilled or {}).items():
119
- node_class = get_class_by_qualname(node)
120
- self._node_executions_fulfilled[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
141
+ node_executions_queued = raw_data.get("node_executions_queued")
142
+ if isinstance(node_executions_queued, dict):
143
+ for node, execution_ids in node_executions_queued.items():
144
+ node_class = nodes.get(node)
145
+ if not node_class:
146
+ continue
121
147
 
122
- for node, execution_ids in (node_executions_initiated or {}).items():
123
- node_class = get_class_by_qualname(node)
124
- self._node_executions_initiated[node_class].update({UUID(execution_id) for execution_id in execution_ids})
148
+ cache._node_executions_queued[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
125
149
 
126
- for node, execution_ids in (node_executions_queued or {}).items():
127
- node_class = get_class_by_qualname(node)
128
- self._node_executions_queued[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
150
+ return cache
129
151
 
130
152
  def _invoke_dependency(
131
153
  self,
@@ -207,8 +229,8 @@ def default_datetime_factory() -> datetime:
207
229
 
208
230
 
209
231
  class StateMeta(UniversalBaseModel):
232
+ workflow_definition: Type["BaseWorkflow"] = field(default_factory=import_workflow_class)
210
233
  id: UUID = field(default_factory=uuid4_default_factory)
211
- trace_id: UUID = field(default_factory=uuid4_default_factory)
212
234
  span_id: UUID = field(default_factory=uuid4_default_factory)
213
235
  updated_ts: datetime = field(default_factory=default_datetime_factory)
214
236
  workflow_inputs: BaseInputs = field(default_factory=BaseInputs)
@@ -219,8 +241,6 @@ class StateMeta(UniversalBaseModel):
219
241
  __snapshot_callback__: Optional[Callable[[], None]] = field(init=False, default=None)
220
242
 
221
243
  def model_post_init(self, context: Any) -> None:
222
- if self.parent:
223
- self.trace_id = self.parent.meta.trace_id
224
244
  self.__snapshot_callback__ = None
225
245
 
226
246
  def add_snapshot_callback(self, callback: Callable[[], None]) -> None:
@@ -237,6 +257,25 @@ class StateMeta(UniversalBaseModel):
237
257
  if callable(self.__snapshot_callback__):
238
258
  self.__snapshot_callback__()
239
259
 
260
+ @field_serializer("workflow_definition")
261
+ def serialize_workflow_definition(self, workflow_definition: Type["BaseWorkflow"], _info: Any) -> Dict[str, Any]:
262
+ return serialize_type_encoder_with_id(workflow_definition)
263
+
264
+ @field_validator("workflow_definition", mode="before")
265
+ @classmethod
266
+ def deserialize_workflow_definition(cls, workflow_definition: Any, info: ValidationInfo):
267
+ if isinstance(workflow_definition, dict):
268
+ deserialized_workflow_definition = CodeResourceDefinition.model_validate(workflow_definition).decode()
269
+ if not is_workflow_class(deserialized_workflow_definition):
270
+ return import_workflow_class()
271
+
272
+ return deserialized_workflow_definition
273
+
274
+ if is_workflow_class(workflow_definition):
275
+ return workflow_definition
276
+
277
+ return import_workflow_class()
278
+
240
279
  @field_serializer("node_outputs")
241
280
  def serialize_node_outputs(self, node_outputs: Dict[OutputReference, Any], _info: Any) -> Dict[str, Any]:
242
281
  return {str(descriptor): value for descriptor, value in node_outputs.items()}
@@ -244,17 +283,16 @@ class StateMeta(UniversalBaseModel):
244
283
  @field_validator("node_outputs", mode="before")
245
284
  @classmethod
246
285
  def deserialize_node_outputs(cls, node_outputs: Any, info: ValidationInfo):
247
- if isinstance(node_outputs, dict) and isinstance(info.context, dict):
248
- raw_workflow_nodes = info.context.get("nodes")
249
- workflow_node_outputs = {}
250
- if isinstance(raw_workflow_nodes, list):
251
- for node in raw_workflow_nodes:
252
- Outputs = getattr(node, "Outputs", None)
253
- if not isinstance(Outputs, type) or not issubclass(Outputs, BaseOutputs):
254
- continue
286
+ if isinstance(node_outputs, dict):
287
+ workflow_definition = cls._get_workflow(info)
288
+ if not workflow_definition:
289
+ return node_outputs
255
290
 
256
- for output in Outputs:
257
- workflow_node_outputs[str(output)] = output
291
+ raw_workflow_nodes = workflow_definition.get_nodes()
292
+ workflow_node_outputs = {}
293
+ for node in raw_workflow_nodes:
294
+ for output in node.Outputs:
295
+ workflow_node_outputs[str(output)] = output
258
296
 
259
297
  node_output_keys = list(node_outputs.keys())
260
298
  deserialized_node_outputs = {}
@@ -269,12 +307,64 @@ class StateMeta(UniversalBaseModel):
269
307
 
270
308
  return node_outputs
271
309
 
310
+ @field_validator("node_execution_cache", mode="before")
311
+ @classmethod
312
+ def deserialize_node_execution_cache(cls, node_execution_cache: Any, info: ValidationInfo):
313
+ if isinstance(node_execution_cache, dict):
314
+ workflow_definition = cls._get_workflow(info)
315
+ if not workflow_definition:
316
+ return node_execution_cache
317
+
318
+ nodes_cache: Dict[str, Type["BaseNode"]] = {}
319
+ raw_workflow_nodes = workflow_definition.get_nodes()
320
+ for node in raw_workflow_nodes:
321
+ nodes_cache[str(node)] = node
322
+
323
+ return NodeExecutionCache.deserialize(node_execution_cache, nodes_cache)
324
+
325
+ return node_execution_cache
326
+
327
+ @field_validator("workflow_inputs", mode="before")
328
+ @classmethod
329
+ def deserialize_workflow_inputs(cls, workflow_inputs: Any, info: ValidationInfo):
330
+ workflow_definition = cls._get_workflow(info)
331
+
332
+ if workflow_definition:
333
+ if workflow_inputs is None:
334
+ return workflow_definition.get_inputs_class()()
335
+ if isinstance(workflow_inputs, dict):
336
+ return workflow_definition.get_inputs_class()(**workflow_inputs)
337
+
338
+ return workflow_inputs
339
+
272
340
  @field_serializer("external_inputs")
273
341
  def serialize_external_inputs(
274
342
  self, external_inputs: Dict[ExternalInputReference, Any], _info: Any
275
343
  ) -> Dict[str, Any]:
276
344
  return {str(descriptor): value for descriptor, value in external_inputs.items()}
277
345
 
346
+ @field_validator("parent", mode="before")
347
+ @classmethod
348
+ def deserialize_parent(cls, parent: Any, info: ValidationInfo):
349
+ if isinstance(parent, dict):
350
+ workflow_definition = cls._get_workflow(info)
351
+ if not workflow_definition:
352
+ return parent
353
+
354
+ parent_meta = parent.get("meta")
355
+ if not isinstance(parent_meta, dict):
356
+ return parent
357
+
358
+ parent_workflow_definition = cls.deserialize_workflow_definition(
359
+ parent_meta.get("workflow_definition"), info
360
+ )
361
+ if not is_workflow_class(parent_workflow_definition):
362
+ return parent
363
+
364
+ return parent_workflow_definition.deserialize_state(parent)
365
+
366
+ return parent
367
+
278
368
  def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "StateMeta":
279
369
  if not memo:
280
370
  memo = {}
@@ -295,6 +385,20 @@ class StateMeta(UniversalBaseModel):
295
385
 
296
386
  return super().__deepcopy__(memo)
297
387
 
388
+ @classmethod
389
+ def _get_workflow(cls, info: ValidationInfo) -> Optional[Type["BaseWorkflow"]]:
390
+ if not info.context:
391
+ return None
392
+
393
+ if not isinstance(info.context, dict):
394
+ return None
395
+
396
+ workflow_definition = info.context.get("workflow_definition")
397
+ if not is_workflow_class(workflow_definition):
398
+ return None
399
+
400
+ return workflow_definition
401
+
298
402
 
299
403
  class BaseState(metaclass=_BaseStateMeta):
300
404
  meta: StateMeta = field(init=False)
@@ -0,0 +1,71 @@
1
+ import importlib
2
+ import inspect
3
+ from types import FrameType
4
+ from uuid import UUID
5
+ from typing import Annotated, Any, Dict, Optional, Union
6
+
7
+ from pydantic import BeforeValidator
8
+
9
+ from vellum.client.types.code_resource_definition import CodeResourceDefinition as ClientCodeResourceDefinition
10
+
11
+
12
+ def serialize_type_encoder(obj: type) -> Dict[str, Any]:
13
+ return {
14
+ "name": obj.__name__,
15
+ "module": obj.__module__.split("."),
16
+ }
17
+
18
+
19
+ def serialize_type_encoder_with_id(obj: Union[type, "CodeResourceDefinition"]) -> Dict[str, Any]:
20
+ if hasattr(obj, "__id__") and isinstance(obj, type):
21
+ return {
22
+ "id": getattr(obj, "__id__"),
23
+ **serialize_type_encoder(obj),
24
+ }
25
+ elif isinstance(obj, CodeResourceDefinition):
26
+ return obj.model_dump(mode="json")
27
+
28
+ raise AttributeError(f"The object of type '{type(obj).__name__}' must have an '__id__' attribute.")
29
+
30
+
31
+ class CodeResourceDefinition(ClientCodeResourceDefinition):
32
+ id: UUID
33
+
34
+ @staticmethod
35
+ def encode(obj: type) -> "CodeResourceDefinition":
36
+ return CodeResourceDefinition(**serialize_type_encoder_with_id(obj))
37
+
38
+ def decode(self) -> Any:
39
+ if ".<locals>." in self.name:
40
+ # We are decoding a local class that should already be loaded in our stack frame. So
41
+ # we climb up to look for it.
42
+ frame = inspect.currentframe()
43
+ return self._resolve_local(frame)
44
+
45
+ try:
46
+ imported_module = importlib.import_module(".".join(self.module))
47
+ except ImportError:
48
+ return None
49
+
50
+ return getattr(imported_module, self.name, None)
51
+
52
+ def _resolve_local(self, frame: Optional[FrameType]) -> Any:
53
+ if not frame:
54
+ return None
55
+
56
+ frame_module = frame.f_globals.get("__name__")
57
+ if not isinstance(frame_module, str) or frame_module.split(".") != self.module:
58
+ return self._resolve_local(frame.f_back)
59
+
60
+ outer, inner = self.name.split(".<locals>.")
61
+ frame_outer = frame.f_code.co_name
62
+ if frame_outer != outer:
63
+ return self._resolve_local(frame.f_back)
64
+
65
+ return frame.f_locals.get(inner)
66
+
67
+
68
+ VellumCodeResourceDefinition = Annotated[
69
+ CodeResourceDefinition,
70
+ BeforeValidator(lambda d: (d if type(d) is dict else serialize_type_encoder_with_id(d))),
71
+ ]
@@ -1,4 +1,6 @@
1
- from typing import TYPE_CHECKING, TypeVar
1
+ from functools import cache
2
+ from typing import TYPE_CHECKING, Any, Type, TypeVar
3
+ from typing_extensions import TypeGuard
2
4
 
3
5
  if TYPE_CHECKING:
4
6
  from vellum.workflows import BaseWorkflow
@@ -12,3 +14,34 @@ StateType = TypeVar("StateType", bound="BaseState")
12
14
  WorkflowType = TypeVar("WorkflowType", bound="BaseWorkflow")
13
15
  InputsType = TypeVar("InputsType", bound="BaseInputs")
14
16
  OutputsType = TypeVar("OutputsType", bound="BaseOutputs")
17
+
18
+
19
+ @cache
20
+ def _import_node_class() -> Type["BaseNode"]:
21
+ """
22
+ Helper function to help avoid circular imports.
23
+ """
24
+
25
+ from vellum.workflows.nodes import BaseNode
26
+
27
+ return BaseNode
28
+
29
+
30
+ def import_workflow_class() -> Type["BaseWorkflow"]:
31
+ """
32
+ Helper function to help avoid circular imports.
33
+ """
34
+
35
+ from vellum.workflows.workflows import BaseWorkflow
36
+
37
+ return BaseWorkflow
38
+
39
+
40
+ def is_node_class(obj: Any) -> TypeGuard[Type["BaseNode"]]:
41
+ base_node_class = _import_node_class()
42
+ return isinstance(obj, type) and issubclass(obj, base_node_class)
43
+
44
+
45
+ def is_workflow_class(obj: Any) -> TypeGuard[Type["BaseWorkflow"]]:
46
+ base_workflow_class = import_workflow_class()
47
+ return isinstance(obj, type) and issubclass(obj, base_workflow_class)
@@ -22,6 +22,7 @@ from typing import (
22
22
  Union,
23
23
  cast,
24
24
  get_args,
25
+ overload,
25
26
  )
26
27
 
27
28
  from vellum.workflows.edges import Edge
@@ -488,11 +489,13 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
488
489
  parent=self._parent_state,
489
490
  workflow_inputs=workflow_inputs or self.get_default_inputs(),
490
491
  trace_id=execution_context.trace_id,
492
+ workflow_definition=self.__class__,
491
493
  )
492
494
  if execution_context and int(execution_context.trace_id)
493
495
  else StateMeta(
494
496
  parent=self._parent_state,
495
497
  workflow_inputs=workflow_inputs or self.get_default_inputs(),
498
+ workflow_definition=self.__class__,
496
499
  )
497
500
  )
498
501
  )
@@ -530,18 +533,30 @@ class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
530
533
 
531
534
  return most_recent_state_snapshot
532
535
 
536
+ @overload
537
+ @classmethod
538
+ def deserialize_state(cls, state: dict, workflow_inputs: Optional[InputsType] = None) -> StateType: ...
539
+
540
+ @overload
533
541
  @classmethod
534
- def deserialize_state(cls, state: dict, workflow_inputs: Optional[InputsType] = None) -> StateType:
542
+ def deserialize_state(cls, state: None, workflow_inputs: Optional[InputsType] = None) -> None: ...
543
+
544
+ @classmethod
545
+ def deserialize_state(
546
+ cls, state: Optional[dict], workflow_inputs: Optional[InputsType] = None
547
+ ) -> Optional[StateType]:
548
+ if state is None:
549
+ return None
550
+
535
551
  state_class = cls.get_state_class()
536
552
  if "meta" in state:
537
- nodes = list(cls.get_nodes())
538
553
  state["meta"] = StateMeta.model_validate(
539
554
  {
540
555
  **state["meta"],
541
556
  "workflow_inputs": workflow_inputs,
542
557
  },
543
558
  context={
544
- "nodes": nodes,
559
+ "workflow_definition": cls,
545
560
  },
546
561
  )
547
562
 
@@ -601,3 +616,5 @@ NodeExecutionRejectedEvent.model_rebuild()
601
616
  NodeExecutionPausedEvent.model_rebuild()
602
617
  NodeExecutionResumedEvent.model_rebuild()
603
618
  NodeExecutionStreamingEvent.model_rebuild()
619
+
620
+ StateMeta.model_rebuild()