vellum-ai 0.9.16rc2__py3-none-any.whl → 0.9.16rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. vellum/plugins/__init__.py +0 -0
  2. vellum/plugins/pydantic.py +74 -0
  3. vellum/plugins/utils.py +19 -0
  4. vellum/plugins/vellum_mypy.py +639 -3
  5. vellum/workflows/README.md +90 -0
  6. vellum/workflows/__init__.py +5 -0
  7. vellum/workflows/constants.py +43 -0
  8. vellum/workflows/descriptors/__init__.py +0 -0
  9. vellum/workflows/descriptors/base.py +339 -0
  10. vellum/workflows/descriptors/tests/test_utils.py +83 -0
  11. vellum/workflows/descriptors/utils.py +90 -0
  12. vellum/workflows/edges/__init__.py +5 -0
  13. vellum/workflows/edges/edge.py +23 -0
  14. vellum/workflows/emitters/__init__.py +5 -0
  15. vellum/workflows/emitters/base.py +14 -0
  16. vellum/workflows/environment/__init__.py +5 -0
  17. vellum/workflows/environment/environment.py +7 -0
  18. vellum/workflows/errors/__init__.py +6 -0
  19. vellum/workflows/errors/types.py +20 -0
  20. vellum/workflows/events/__init__.py +31 -0
  21. vellum/workflows/events/node.py +125 -0
  22. vellum/workflows/events/tests/__init__.py +0 -0
  23. vellum/workflows/events/tests/test_event.py +216 -0
  24. vellum/workflows/events/types.py +52 -0
  25. vellum/workflows/events/utils.py +5 -0
  26. vellum/workflows/events/workflow.py +139 -0
  27. vellum/workflows/exceptions.py +15 -0
  28. vellum/workflows/expressions/__init__.py +0 -0
  29. vellum/workflows/expressions/accessor.py +52 -0
  30. vellum/workflows/expressions/and_.py +32 -0
  31. vellum/workflows/expressions/begins_with.py +31 -0
  32. vellum/workflows/expressions/between.py +38 -0
  33. vellum/workflows/expressions/coalesce_expression.py +41 -0
  34. vellum/workflows/expressions/contains.py +30 -0
  35. vellum/workflows/expressions/does_not_begin_with.py +31 -0
  36. vellum/workflows/expressions/does_not_contain.py +30 -0
  37. vellum/workflows/expressions/does_not_end_with.py +31 -0
  38. vellum/workflows/expressions/does_not_equal.py +25 -0
  39. vellum/workflows/expressions/ends_with.py +31 -0
  40. vellum/workflows/expressions/equals.py +25 -0
  41. vellum/workflows/expressions/greater_than.py +33 -0
  42. vellum/workflows/expressions/greater_than_or_equal_to.py +33 -0
  43. vellum/workflows/expressions/in_.py +31 -0
  44. vellum/workflows/expressions/is_blank.py +24 -0
  45. vellum/workflows/expressions/is_not_blank.py +24 -0
  46. vellum/workflows/expressions/is_not_null.py +21 -0
  47. vellum/workflows/expressions/is_not_undefined.py +22 -0
  48. vellum/workflows/expressions/is_null.py +21 -0
  49. vellum/workflows/expressions/is_undefined.py +22 -0
  50. vellum/workflows/expressions/less_than.py +33 -0
  51. vellum/workflows/expressions/less_than_or_equal_to.py +33 -0
  52. vellum/workflows/expressions/not_between.py +38 -0
  53. vellum/workflows/expressions/not_in.py +31 -0
  54. vellum/workflows/expressions/or_.py +32 -0
  55. vellum/workflows/graph/__init__.py +3 -0
  56. vellum/workflows/graph/graph.py +131 -0
  57. vellum/workflows/graph/tests/__init__.py +0 -0
  58. vellum/workflows/graph/tests/test_graph.py +437 -0
  59. vellum/workflows/inputs/__init__.py +5 -0
  60. vellum/workflows/inputs/base.py +55 -0
  61. vellum/workflows/logging.py +14 -0
  62. vellum/workflows/nodes/__init__.py +46 -0
  63. vellum/workflows/nodes/bases/__init__.py +7 -0
  64. vellum/workflows/nodes/bases/base.py +332 -0
  65. vellum/workflows/nodes/bases/base_subworkflow_node/__init__.py +5 -0
  66. vellum/workflows/nodes/bases/base_subworkflow_node/node.py +10 -0
  67. vellum/workflows/nodes/bases/tests/__init__.py +0 -0
  68. vellum/workflows/nodes/bases/tests/test_base_node.py +125 -0
  69. vellum/workflows/nodes/core/__init__.py +16 -0
  70. vellum/workflows/nodes/core/error_node/__init__.py +5 -0
  71. vellum/workflows/nodes/core/error_node/node.py +26 -0
  72. vellum/workflows/nodes/core/inline_subworkflow_node/__init__.py +5 -0
  73. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +73 -0
  74. vellum/workflows/nodes/core/map_node/__init__.py +5 -0
  75. vellum/workflows/nodes/core/map_node/node.py +147 -0
  76. vellum/workflows/nodes/core/map_node/tests/__init__.py +0 -0
  77. vellum/workflows/nodes/core/map_node/tests/test_node.py +65 -0
  78. vellum/workflows/nodes/core/retry_node/__init__.py +5 -0
  79. vellum/workflows/nodes/core/retry_node/node.py +106 -0
  80. vellum/workflows/nodes/core/retry_node/tests/__init__.py +0 -0
  81. vellum/workflows/nodes/core/retry_node/tests/test_node.py +93 -0
  82. vellum/workflows/nodes/core/templating_node/__init__.py +5 -0
  83. vellum/workflows/nodes/core/templating_node/custom_filters.py +12 -0
  84. vellum/workflows/nodes/core/templating_node/exceptions.py +2 -0
  85. vellum/workflows/nodes/core/templating_node/node.py +123 -0
  86. vellum/workflows/nodes/core/templating_node/render.py +55 -0
  87. vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +21 -0
  88. vellum/workflows/nodes/core/try_node/__init__.py +5 -0
  89. vellum/workflows/nodes/core/try_node/node.py +110 -0
  90. vellum/workflows/nodes/core/try_node/tests/__init__.py +0 -0
  91. vellum/workflows/nodes/core/try_node/tests/test_node.py +82 -0
  92. vellum/workflows/nodes/displayable/__init__.py +31 -0
  93. vellum/workflows/nodes/displayable/api_node/__init__.py +5 -0
  94. vellum/workflows/nodes/displayable/api_node/node.py +44 -0
  95. vellum/workflows/nodes/displayable/bases/__init__.py +11 -0
  96. vellum/workflows/nodes/displayable/bases/api_node/__init__.py +5 -0
  97. vellum/workflows/nodes/displayable/bases/api_node/node.py +70 -0
  98. vellum/workflows/nodes/displayable/bases/base_prompt_node/__init__.py +5 -0
  99. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +60 -0
  100. vellum/workflows/nodes/displayable/bases/inline_prompt_node/__init__.py +5 -0
  101. vellum/workflows/nodes/displayable/bases/inline_prompt_node/constants.py +13 -0
  102. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +118 -0
  103. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +98 -0
  104. vellum/workflows/nodes/displayable/bases/search_node.py +90 -0
  105. vellum/workflows/nodes/displayable/code_execution_node/__init__.py +5 -0
  106. vellum/workflows/nodes/displayable/code_execution_node/node.py +197 -0
  107. vellum/workflows/nodes/displayable/code_execution_node/tests/__init__.py +0 -0
  108. vellum/workflows/nodes/displayable/code_execution_node/tests/fixtures/__init__.py +0 -0
  109. vellum/workflows/nodes/displayable/code_execution_node/tests/fixtures/main.py +3 -0
  110. vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +111 -0
  111. vellum/workflows/nodes/displayable/code_execution_node/utils.py +10 -0
  112. vellum/workflows/nodes/displayable/conditional_node/__init__.py +5 -0
  113. vellum/workflows/nodes/displayable/conditional_node/node.py +25 -0
  114. vellum/workflows/nodes/displayable/final_output_node/__init__.py +5 -0
  115. vellum/workflows/nodes/displayable/final_output_node/node.py +43 -0
  116. vellum/workflows/nodes/displayable/guardrail_node/__init__.py +5 -0
  117. vellum/workflows/nodes/displayable/guardrail_node/node.py +97 -0
  118. vellum/workflows/nodes/displayable/inline_prompt_node/__init__.py +5 -0
  119. vellum/workflows/nodes/displayable/inline_prompt_node/node.py +41 -0
  120. vellum/workflows/nodes/displayable/merge_node/__init__.py +5 -0
  121. vellum/workflows/nodes/displayable/merge_node/node.py +10 -0
  122. vellum/workflows/nodes/displayable/prompt_deployment_node/__init__.py +5 -0
  123. vellum/workflows/nodes/displayable/prompt_deployment_node/node.py +45 -0
  124. vellum/workflows/nodes/displayable/search_node/__init__.py +5 -0
  125. vellum/workflows/nodes/displayable/search_node/node.py +26 -0
  126. vellum/workflows/nodes/displayable/subworkflow_deployment_node/__init__.py +5 -0
  127. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +156 -0
  128. vellum/workflows/nodes/displayable/tests/__init__.py +0 -0
  129. vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py +148 -0
  130. vellum/workflows/nodes/displayable/tests/test_search_node_wth_text_output.py +134 -0
  131. vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +80 -0
  132. vellum/workflows/nodes/utils.py +27 -0
  133. vellum/workflows/outputs/__init__.py +6 -0
  134. vellum/workflows/outputs/base.py +196 -0
  135. vellum/workflows/ports/__init__.py +7 -0
  136. vellum/workflows/ports/node_ports.py +75 -0
  137. vellum/workflows/ports/port.py +75 -0
  138. vellum/workflows/ports/utils.py +40 -0
  139. vellum/workflows/references/__init__.py +17 -0
  140. vellum/workflows/references/environment_variable.py +20 -0
  141. vellum/workflows/references/execution_count.py +20 -0
  142. vellum/workflows/references/external_input.py +49 -0
  143. vellum/workflows/references/input.py +7 -0
  144. vellum/workflows/references/lazy.py +55 -0
  145. vellum/workflows/references/node.py +43 -0
  146. vellum/workflows/references/output.py +78 -0
  147. vellum/workflows/references/state_value.py +23 -0
  148. vellum/workflows/references/vellum_secret.py +15 -0
  149. vellum/workflows/references/workflow_input.py +41 -0
  150. vellum/workflows/resolvers/__init__.py +5 -0
  151. vellum/workflows/resolvers/base.py +15 -0
  152. vellum/workflows/runner/__init__.py +5 -0
  153. vellum/workflows/runner/runner.py +588 -0
  154. vellum/workflows/runner/types.py +18 -0
  155. vellum/workflows/state/__init__.py +5 -0
  156. vellum/workflows/state/base.py +327 -0
  157. vellum/workflows/state/context.py +18 -0
  158. vellum/workflows/state/encoder.py +57 -0
  159. vellum/workflows/state/store.py +28 -0
  160. vellum/workflows/state/tests/__init__.py +0 -0
  161. vellum/workflows/state/tests/test_state.py +113 -0
  162. vellum/workflows/types/__init__.py +0 -0
  163. vellum/workflows/types/core.py +91 -0
  164. vellum/workflows/types/generics.py +14 -0
  165. vellum/workflows/types/stack.py +39 -0
  166. vellum/workflows/types/tests/__init__.py +0 -0
  167. vellum/workflows/types/tests/test_utils.py +76 -0
  168. vellum/workflows/types/utils.py +164 -0
  169. vellum/workflows/utils/__init__.py +0 -0
  170. vellum/workflows/utils/names.py +13 -0
  171. vellum/workflows/utils/tests/__init__.py +0 -0
  172. vellum/workflows/utils/tests/test_names.py +15 -0
  173. vellum/workflows/utils/tests/test_vellum_variables.py +25 -0
  174. vellum/workflows/utils/vellum_variables.py +81 -0
  175. vellum/workflows/vellum_client.py +18 -0
  176. vellum/workflows/workflows/__init__.py +5 -0
  177. vellum/workflows/workflows/base.py +365 -0
  178. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/METADATA +2 -1
  179. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/RECORD +245 -7
  180. vellum_cli/__init__.py +72 -0
  181. vellum_cli/aliased_group.py +103 -0
  182. vellum_cli/config.py +96 -0
  183. vellum_cli/image_push.py +112 -0
  184. vellum_cli/logger.py +36 -0
  185. vellum_cli/pull.py +73 -0
  186. vellum_cli/push.py +121 -0
  187. vellum_cli/tests/test_config.py +100 -0
  188. vellum_cli/tests/test_pull.py +152 -0
  189. vellum_ee/workflows/__init__.py +0 -0
  190. vellum_ee/workflows/display/__init__.py +0 -0
  191. vellum_ee/workflows/display/base.py +73 -0
  192. vellum_ee/workflows/display/nodes/__init__.py +4 -0
  193. vellum_ee/workflows/display/nodes/base_node_display.py +116 -0
  194. vellum_ee/workflows/display/nodes/base_node_vellum_display.py +36 -0
  195. vellum_ee/workflows/display/nodes/get_node_display_class.py +25 -0
  196. vellum_ee/workflows/display/nodes/tests/__init__.py +0 -0
  197. vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +47 -0
  198. vellum_ee/workflows/display/nodes/types.py +18 -0
  199. vellum_ee/workflows/display/nodes/utils.py +33 -0
  200. vellum_ee/workflows/display/nodes/vellum/__init__.py +32 -0
  201. vellum_ee/workflows/display/nodes/vellum/api_node.py +205 -0
  202. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +71 -0
  203. vellum_ee/workflows/display/nodes/vellum/conditional_node.py +217 -0
  204. vellum_ee/workflows/display/nodes/vellum/final_output_node.py +61 -0
  205. vellum_ee/workflows/display/nodes/vellum/guardrail_node.py +49 -0
  206. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +170 -0
  207. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +99 -0
  208. vellum_ee/workflows/display/nodes/vellum/map_node.py +100 -0
  209. vellum_ee/workflows/display/nodes/vellum/merge_node.py +48 -0
  210. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +68 -0
  211. vellum_ee/workflows/display/nodes/vellum/search_node.py +193 -0
  212. vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +58 -0
  213. vellum_ee/workflows/display/nodes/vellum/templating_node.py +67 -0
  214. vellum_ee/workflows/display/nodes/vellum/tests/__init__.py +0 -0
  215. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +106 -0
  216. vellum_ee/workflows/display/nodes/vellum/try_node.py +38 -0
  217. vellum_ee/workflows/display/nodes/vellum/utils.py +76 -0
  218. vellum_ee/workflows/display/tests/__init__.py +0 -0
  219. vellum_ee/workflows/display/tests/workflow_serialization/__init__.py +0 -0
  220. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +426 -0
  221. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +607 -0
  222. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +1175 -0
  223. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +235 -0
  224. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +511 -0
  225. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +372 -0
  226. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +272 -0
  227. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +289 -0
  228. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +354 -0
  229. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +123 -0
  230. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +84 -0
  231. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +233 -0
  232. vellum_ee/workflows/display/types.py +46 -0
  233. vellum_ee/workflows/display/utils/__init__.py +0 -0
  234. vellum_ee/workflows/display/utils/tests/__init__.py +0 -0
  235. vellum_ee/workflows/display/utils/tests/test_uuids.py +16 -0
  236. vellum_ee/workflows/display/utils/uuids.py +24 -0
  237. vellum_ee/workflows/display/utils/vellum.py +121 -0
  238. vellum_ee/workflows/display/vellum.py +357 -0
  239. vellum_ee/workflows/display/workflows/__init__.py +5 -0
  240. vellum_ee/workflows/display/workflows/base_workflow_display.py +302 -0
  241. vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +32 -0
  242. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +386 -0
  243. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/LICENSE +0 -0
  244. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/WHEEL +0 -0
  245. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,327 @@
1
+ from collections import defaultdict
2
+ from copy import deepcopy
3
+ from dataclasses import field
4
+ from datetime import datetime
5
+ from queue import Queue
6
+ from threading import Lock
7
+ from uuid import UUID, uuid4
8
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, Optional, Sequence, Set, Tuple, Type, cast
9
+ from typing_extensions import dataclass_transform
10
+
11
+ from pydantic import GetCoreSchemaHandler, field_serializer
12
+ from pydantic_core import core_schema
13
+
14
+ from vellum.core.pydantic_utilities import UniversalBaseModel
15
+
16
+ from vellum.workflows.constants import UNDEF
17
+ from vellum.workflows.inputs.base import BaseInputs
18
+ from vellum.workflows.references import ExternalInputReference, OutputReference, StateValueReference
19
+ from vellum.workflows.types.generics import StateType
20
+ from vellum.workflows.types.stack import Stack
21
+ from vellum.workflows.types.utils import datetime_now, deepcopy_with_exclusions, get_class_by_qualname, infer_types
22
+
23
+ if TYPE_CHECKING:
24
+ from vellum.workflows.nodes.bases import BaseNode
25
+
26
+
27
class _Snapshottable:
    """
    Mixin for containers that call back into their owning state whenever they are edited.

    On deep copy, the current callback is handed to `deepcopy_with_exclusions` as an exclusion. Presumably
    this means it is carried over to the copy rather than deep-copied; confirm against that helper.
    """

    _snapshot_callback: Callable[[], None]

    def __deepcopy__(self, memo: Any) -> "_Snapshottable":
        preserved_attributes = {"_snapshot_callback": self._snapshot_callback}
        return deepcopy_with_exclusions(self, memo=memo, exclusions=preserved_attributes)
38
+
39
+
40
@dataclass_transform(kw_only_default=True)
class _BaseStateMeta(type):
    """
    Metaclass for state classes.

    Looking up a public (non-underscore) attribute on a state *class* returns a `StateValueReference`
    that describes the attribute (its name, inferred types, and any class-level value), instead of the
    raw class attribute. Underscore-prefixed names resolve through the normal lookup.
    """

    def __getattribute__(cls, name: str) -> Any:
        if name.startswith("_"):
            return super().__getattribute__(name)

        declared_value = vars(cls).get(name)
        inferred_types = infer_types(cls, name)
        return StateValueReference(name=name, types=inferred_types, instance=declared_value)
49
+
50
+
51
class _SnapshottableDict(dict, _Snapshottable):
    """
    A `dict` that invokes its snapshot callback after in-place mutations.

    The C-implemented `dict` methods (`update`, `pop`, `setdefault`, `clear`, ...) do not route through
    an overridden `__setitem__`/`__delitem__`. Each mutating method is therefore wrapped explicitly;
    otherwise edits such as `state.some_dict.update(...)` or `del state.some_dict[key]` would skip the
    snapshot.
    """

    def __setitem__(self, key: Any, value: Any) -> None:
        super().__setitem__(key, value)
        self._snapshot_callback()

    def __delitem__(self, key: Any) -> None:
        super().__delitem__(key)
        self._snapshot_callback()

    def update(self, *args: Any, **kwargs: Any) -> None:
        super().update(*args, **kwargs)
        self._snapshot_callback()

    def pop(self, key: Any, *default: Any) -> Any:
        # `pop(missing_key, default)` leaves the dict unchanged, so only snapshot when a key was removed.
        removed = key in self
        value = super().pop(key, *default)
        if removed:
            self._snapshot_callback()
        return value

    def popitem(self) -> Tuple[Any, Any]:
        # Raises KeyError on an empty dict before any snapshot is taken.
        item = super().popitem()
        self._snapshot_callback()
        return item

    def setdefault(self, key: Any, default: Any = None) -> Any:
        # Only an actual insertion is an edit; returning an existing value is not.
        inserted = key not in self
        value = super().setdefault(key, default)
        if inserted:
            self._snapshot_callback()
        return value

    def clear(self) -> None:
        had_items = bool(self)
        super().clear()
        if had_items:
            self._snapshot_callback()
55
+
56
+
57
def _make_snapshottable(value: Any, snapshot_callback: Callable[[], None]) -> Any:
    """
    Wrap `value` so that edits to it invoke `snapshot_callback`.

    Only plain dicts are wrapped, as a new `_SnapshottableDict` built from the original's items. Values
    that are already snapshottable, and values of any other type, are returned untouched. This lives at
    module level rather than on `BaseState` so it cannot collide with attribute names declared by state
    subclasses.
    """
    if isinstance(value, dict) and not isinstance(value, _Snapshottable):
        wrapped = _SnapshottableDict(value)
        wrapped._snapshot_callback = snapshot_callback
        return wrapped

    return value
71
+
72
+
73
class NodeExecutionCache:
    """
    Bookkeeping of node executions. It tracks:

    - per node name, the names of dependencies it has invoked;
    - per node class, the execution ids that were initiated and not yet fulfilled;
    - per node class, a stack of fulfilled execution ids (whose size is the execution count).

    Constructor arguments use string keys and values. Node keys are qualnames resolved with
    `get_class_by_qualname`, and execution ids are UUID strings.
    """

    _node_execution_ids: Dict[Type["BaseNode"], Stack[UUID]]
    _node_executions_initiated: Dict[Type["BaseNode"], Set[UUID]]
    _dependencies_invoked: Dict[str, Set[str]]

    def __init__(
        self,
        dependencies_invoked: Optional[Dict[str, Sequence[str]]] = None,
        node_execution_ids: Optional[Dict[str, Sequence[str]]] = None,
        node_executions_initiated: Optional[Dict[str, Sequence[str]]] = None,
    ) -> None:
        self._dependencies_invoked = defaultdict(set)
        self._node_execution_ids = defaultdict(Stack[UUID])
        self._node_executions_initiated = defaultdict(set)

        if dependencies_invoked:
            for node_name, dependency_names in dependencies_invoked.items():
                self._dependencies_invoked[node_name].update(dependency_names)

        if node_execution_ids:
            for node_qualname, raw_ids in node_execution_ids.items():
                fulfilled_stack = self._node_execution_ids[get_class_by_qualname(node_qualname)]
                fulfilled_stack.extend(UUID(raw_id) for raw_id in raw_ids)

        if node_executions_initiated:
            for node_qualname, raw_ids in node_executions_initiated.items():
                in_flight = self._node_executions_initiated[get_class_by_qualname(node_qualname)]
                in_flight.update({UUID(raw_id) for raw_id in raw_ids})

    @property
    def dependencies_invoked(self) -> Dict[str, Set[str]]:
        return self._dependencies_invoked

    def is_node_initiated(self, node: Type["BaseNode"]) -> bool:
        # `.get` avoids inserting an empty entry into the defaultdict merely by asking.
        return bool(self._node_executions_initiated.get(node))

    def initiate_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
        self._node_executions_initiated[node].add(execution_id)

    def fulfill_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
        # Moves the id from the in-flight set onto the fulfilled stack.
        # `set.remove` raises KeyError if it was never initiated.
        self._node_executions_initiated[node].remove(execution_id)
        self._node_execution_ids[node].push(execution_id)

    def get_execution_count(self, node: Type["BaseNode"]) -> int:
        return self._node_execution_ids[node].size()

    def dump(self) -> Dict[str, Any]:
        dependencies = {node_name: list(names) for node_name, names in self._dependencies_invoked.items()}
        initiated = {str(node): list(ids) for node, ids in self._node_executions_initiated.items()}
        fulfilled = {str(node): stack.dump() for node, stack in self._node_execution_ids.items()}
        return {
            "dependencies_invoked": dependencies,
            "node_executions_initiated": initiated,
            "node_execution_ids": fulfilled,
        }

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Type[Any], handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # Pydantic validates this type only by an isinstance check.
        return core_schema.is_instance_schema(cls)
134
+
135
+
136
def uuid4_default_factory() -> UUID:
    """
    Default factory for UUID fields: return a freshly generated random (version 4) UUID.

    Kept as a named module-level function so tests can patch it to get deterministic ids.
    """
    fresh_id = uuid4()
    return fresh_id
141
+
142
+
143
def default_datetime_factory() -> datetime:
    """
    Default factory for timestamp fields: return the current time as produced by `datetime_now()`.

    Kept as a named module-level function so tests can patch it to pin timestamps.
    """
    current_time = datetime_now()
    return current_time
149
+
150
+
151
class StateMeta(UniversalBaseModel):
    """
    Metadata attached to every `BaseState`. It holds identifiers, the last-updated timestamp, workflow
    and external inputs, node outputs, and node execution bookkeeping.

    Assigning any attribute other than dunder names and `updated_ts` invokes the registered snapshot
    callback, if one is set.
    """

    id: UUID = field(default_factory=uuid4_default_factory)
    trace_id: UUID = field(default_factory=uuid4_default_factory)
    span_id: UUID = field(default_factory=uuid4_default_factory)
    updated_ts: datetime = field(default_factory=default_datetime_factory)
    workflow_inputs: BaseInputs = field(default_factory=BaseInputs)
    external_inputs: Dict[ExternalInputReference, Any] = field(default_factory=dict)
    node_outputs: Dict[OutputReference, Any] = field(default_factory=dict)
    node_execution_cache: NodeExecutionCache = field(default_factory=NodeExecutionCache)
    parent: Optional["BaseState"] = None
    is_terminated: Optional[bool] = None
    __snapshot_callback__: Optional[Callable[[], None]] = field(init=False, default=None)

    def model_post_init(self, context: Any) -> None:
        # Metadata created with a parent state inherits the parent's trace id.
        if self.parent:
            self.trace_id = self.parent.meta.trace_id
        self.__snapshot_callback__ = None

    def add_snapshot_callback(self, callback: Callable[[], None]) -> None:
        """
        Register `callback` to run on attribute edits.

        Also wraps `node_outputs` so that key assignments into it invoke `callback` too.
        """
        self.node_outputs = _make_snapshottable(self.node_outputs, callback)
        self.__snapshot_callback__ = callback

    def __setattr__(self, name: str, value: Any) -> None:
        # Dunder bookkeeping attributes and `updated_ts` are assigned without triggering a snapshot.
        if name.startswith("__") or name == "updated_ts":
            super().__setattr__(name, value)
            return

        super().__setattr__(name, value)
        if callable(self.__snapshot_callback__):
            self.__snapshot_callback__()

    @field_serializer("node_outputs")
    def serialize_node_outputs(self, node_outputs: Dict[OutputReference, Any], _info: Any) -> Dict[str, Any]:
        # Output references are not JSON keys; serialize them by their string form.
        return {str(descriptor): value for descriptor, value in node_outputs.items()}

    @field_serializer("external_inputs")
    def serialize_external_inputs(
        self, external_inputs: Dict[ExternalInputReference, Any], _info: Any
    ) -> Dict[str, Any]:
        return {str(descriptor): value for descriptor, value in external_inputs.items()}

    def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "StateMeta":
        # Only substitute a new memo when none was given. `copy.deepcopy` passes an *empty* dict on the
        # top-level call, and replacing it would disconnect the entries recorded below from the caller's memo.
        if memo is None:
            memo = {}

        # Output keys are kept as-is and `Queue` values are shared by reference; every other value is deep-copied.
        new_node_outputs = {
            descriptor: value if isinstance(value, Queue) else deepcopy(value, memo)
            for descriptor, value in self.node_outputs.items()
        }

        # Pre-seed the memo so the model copy below reuses `new_node_outputs` and maps the snapshot callback
        # to None, leaving the copy without a registered callback.
        memo[id(self.node_outputs)] = new_node_outputs
        memo[id(self.__snapshot_callback__)] = None

        return super().__deepcopy__(memo)
205
+
206
+
207
def _get_state_defaults(state_class: Type["BaseState"]) -> Dict[str, Any]:
    """
    Collect the class-level default values of `state_class`.

    Covers `state_class` and every `BaseState` subclass it inherits from (excluding `BaseState` itself).
    Declarations on a subclass override those on its parents.

    Skipped names:
    - underscore-prefixed names and `meta`;
    - functions, properties, classmethods and staticmethods. These describe behavior, not state values.
      Copying one onto an instance would shadow the method, or fail outright for a property.
    """
    from types import FunctionType

    defaults: Dict[str, Any] = {}
    for klass in reversed(state_class.__mro__):
        if klass is BaseState or not issubclass(klass, BaseState):
            continue

        for name, value in vars(klass).items():
            if name.startswith("_") or name == "meta":
                continue
            if isinstance(value, (FunctionType, property, classmethod, staticmethod)):
                # A subclass redefining an inherited field as a method must not keep the parent's value.
                defaults.pop(name, None)
                continue
            defaults[name] = value
    return defaults


class BaseState(metaclass=_BaseStateMeta):
    """
    Base class for workflow state.

    Public attributes declared on subclasses (optionally with defaults) hold the state's values. After
    construction, assigning to a public attribute does three things:
    - wraps the value with `_make_snapshottable`;
    - bumps `meta.updated_ts`;
    - emits a snapshot through `__snapshot__`.
    """

    meta: StateMeta = field(init=False)

    __lock__: Lock = field(init=False)
    __is_initializing__: bool = field(init=False)
    __snapshot_callback__: Callable[["BaseState"], None] = field(init=False)

    def __init__(self, meta: Optional[StateMeta] = None, **kwargs: Any) -> None:
        self.__is_initializing__ = True
        self.__snapshot_callback__ = lambda state: None
        self.__lock__ = Lock()

        self.meta = meta or StateMeta()
        self.meta.add_snapshot_callback(self.__snapshot__)

        # Copy every declared default onto the instance, wrapped so edits snapshot. This includes defaults
        # inherited from parent state classes. `super().__setattr__` bypasses the snapshotting in our
        # own `__setattr__`.
        for name, value in _get_state_defaults(type(self)).items():
            super().__setattr__(name, _make_snapshottable(value, self.__snapshot__))

        # While `__is_initializing__` is True, our `__setattr__` assigns these without snapshotting.
        for name, value in kwargs.items():
            setattr(self, name, value)

        self.__is_initializing__ = False

    def __deepcopy__(self, memo: Any) -> "BaseState":
        # The copy gets a brand-new Lock instead of a copy of ours. Exact handling of exclusions lives in
        # `deepcopy_with_exclusions`.
        new_state = deepcopy_with_exclusions(
            self,
            exclusions={
                "__lock__": Lock(),
            },
            memo=memo,
        )
        # Point the copy's meta callback at the copy's own `__snapshot__`.
        new_state.meta.add_snapshot_callback(new_state.__snapshot__)
        return new_state

    def __repr__(self) -> str:
        values = "\n".join(
            [f" {key}={value}" for key, value in vars(self).items() if not key.startswith("_") and key != "meta"]
        )
        node_outputs = "\n".join([f" {key}={value}" for key, value in self.meta.node_outputs.items()])
        return f"""\
{self.__class__.__name__}:
{values}
meta:
id={self.meta.id}
is_terminated={self.meta.is_terminated}
updated_ts={self.meta.updated_ts}
node_outputs:{' Empty' if not node_outputs else ''}
{node_outputs}
"""

    def __iter__(self) -> Iterator[Tuple[Any, Any]]:
        """
        Returns an iterator treating all state keys as (key, value) items, allowing consumers to call `dict()`
        on an instance of this class.

        Class-declared defaults, including inherited ones, are merged underneath the instance attributes,
        and instance values win. This covers any default not present on the instance. Private
        (underscore-prefixed) names are excluded on both sides.
        """
        all_attributes = {**_get_state_defaults(type(self)), **vars(self)}
        items = [(key, value) for key, value in all_attributes.items() if not key.startswith("_")]
        return iter(items)

    def __getitem__(self, key: str) -> Any:
        return self.__dict__[key]

    def __setattr__(self, name: str, value: Any) -> None:
        # Private attributes, and any assignment during __init__, are stored without snapshotting.
        if name.startswith("_") or self.__is_initializing__:
            super().__setattr__(name, value)
            return

        snapshottable_value = _make_snapshottable(value, self.__snapshot__)
        super().__setattr__(name, snapshottable_value)
        self.meta.updated_ts = datetime_now()
        self.__snapshot__()

    def __add__(self, other: StateType) -> StateType:
        """
        Handles merging two states together, preferring the latest state by updated_ts for any given node output.

        The state with the newer `meta.updated_ts` (self on ties) is updated in place. It receives any node
        outputs, and any public attributes it lacks, from the older state. That newer state is returned.

        Raises:
            TypeError: if `other` is not an instance of this state's class.
        """

        if not isinstance(other, type(self)):
            raise TypeError(f"Cannot add {type(other).__name__} to {type(self).__name__}")

        latest_state = self if self.meta.updated_ts >= other.meta.updated_ts else other
        oldest_state = other if latest_state is self else self

        for descriptor, value in oldest_state.meta.node_outputs.items():
            if descriptor not in latest_state.meta.node_outputs:
                latest_state.meta.node_outputs[descriptor] = value

        for key, value in oldest_state:
            if not isinstance(key, str) or key.startswith("_"):
                continue

            # Identity check against the sentinel. `==` would call the stored value's own __eq__, which can
            # return non-bool results (e.g. array-likes) or match the sentinel spuriously.
            if getattr(latest_state, key, UNDEF) is UNDEF:
                setattr(latest_state, key, value)

        return cast(StateType, latest_state)

    def __snapshot__(self) -> None:
        """
        Snapshots the current state to the workflow emitter. The invoked callback is overridden by the
        workflow runner.
        """
        self.__snapshot_callback__(deepcopy(self))

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Type[Any], handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # Pydantic validates this type only by an isinstance check.
        return core_schema.is_instance_schema(cls)
@@ -0,0 +1,18 @@
1
+ from functools import cached_property
2
+ from typing import Optional
3
+
4
+ from vellum import Vellum
5
+
6
+ from vellum.workflows.vellum_client import create_vellum_client
7
+
8
+
9
class WorkflowContext:
    """
    Context object that supplies the Vellum API client to workflow code.
    """

    def __init__(self, _vellum_client: Optional[Vellum] = None) -> None:
        self._vellum_client = _vellum_client

    @cached_property
    def vellum_client(self) -> Vellum:
        """
        The client passed to the constructor, or else one built by `create_vellum_client()`.

        The result is cached after first access. The check is `is not None` rather than truthiness, so an
        explicitly provided client is never silently replaced.
        """
        if self._vellum_client is not None:
            return self._vellum_client

        return create_vellum_client()
@@ -0,0 +1,57 @@
1
+ from dataclasses import asdict, is_dataclass
2
+ from datetime import datetime
3
+ import enum
4
+ from json import JSONEncoder
5
+ from uuid import UUID
6
+ from typing import Any, Callable, Dict, Type
7
+
8
+ from pydantic import BaseModel
9
+
10
+ from vellum.workflows.inputs.base import BaseInputs
11
+ from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
12
+ from vellum.workflows.state.base import BaseState, NodeExecutionCache
13
+
14
+
15
class DefaultStateEncoder(JSONEncoder):
    """
    JSON encoder for workflow state and the values it commonly holds.

    `encoders` maps a type to a callable that returns a JSON-serializable value. It is consulted after the
    built-in branches below. Lookup tries the object's class first, then its base classes in MRO order.
    """

    encoders: Dict[Type, Callable] = {}

    def default(self, obj: Any) -> Any:
        if isinstance(obj, BaseState):
            return dict(obj)

        if isinstance(obj, (BaseInputs, BaseOutputs)):
            return {descriptor.name: value for descriptor, value in obj}

        if isinstance(obj, BaseOutput):
            return obj.serialize()

        if isinstance(obj, NodeExecutionCache):
            return obj.dump()

        if isinstance(obj, UUID):
            return str(obj)

        if isinstance(obj, set):
            return list(obj)

        if isinstance(obj, BaseModel):
            return obj.model_dump()

        if isinstance(obj, datetime):
            return obj.isoformat()

        if isinstance(obj, enum.Enum):
            return obj.value

        # `is_dataclass` is also true for dataclass *types*, which `asdict` rejects. Those fall through to the
        # generic `type` branch below instead.
        if is_dataclass(obj) and not isinstance(obj, type):
            return asdict(obj)  # type: ignore[call-overload]

        if isinstance(obj, type):
            return str(obj)

        # The exact class is checked first (it leads the MRO), so existing registrations behave as before.
        for klass in obj.__class__.__mro__:
            if klass in self.encoders:
                return self.encoders[klass](obj)

        return super().default(obj)
@@ -0,0 +1,28 @@
1
+ from typing import Iterator, List
2
+
3
+ from vellum.workflows.events.workflow import WorkflowEvent
4
+ from vellum.workflows.state.base import BaseState
5
+
6
+
7
class Store:
    """
    In-memory, append-only record of workflow events and state snapshots, resettable via `clear()`.
    """

    def __init__(self) -> None:
        self._events: List[WorkflowEvent] = []
        self._state_snapshots: List[BaseState] = []

    def append_event(self, event: WorkflowEvent) -> None:
        events = self._events
        events.append(event)

    def append_state_snapshot(self, state: BaseState) -> None:
        snapshots = self._state_snapshots
        snapshots.append(state)

    def clear(self) -> None:
        # Rebind to fresh lists rather than clearing in place, so iterators handed out earlier keep their items.
        self._events, self._state_snapshots = [], []

    @property
    def events(self) -> Iterator[WorkflowEvent]:
        recorded_events = self._events
        return iter(recorded_events)

    @property
    def state_snapshots(self) -> Iterator[BaseState]:
        recorded_snapshots = self._state_snapshots
        return iter(recorded_snapshots)
File without changes
@@ -0,0 +1,113 @@
1
+ from collections import defaultdict
2
+ from copy import deepcopy
3
+ import json
4
+ from typing import Dict
5
+
6
+ from vellum.workflows.nodes.bases import BaseNode
7
+ from vellum.workflows.outputs.base import BaseOutputs
8
+ from vellum.workflows.state.base import BaseState
9
+ from vellum.workflows.state.encoder import DefaultStateEncoder
10
+
11
# Tally of ``__snapshot__`` calls, keyed by the ``id()`` of the MockState instance that made them.
snapshot_count: Dict[int, int] = defaultdict(int)


class MockState(BaseState):
    # Test state: one required string field, plus a dict field declared with an
    # empty-dict default (mutated by the nested-dictionary test).
    foo: str
    nested_dict: Dict[str, int] = {}

    def __snapshot__(self) -> None:
        # Snapshot hook. The tests in this module expect one call per mutation;
        # presumably BaseState invokes it (confirm against BaseState).
        # The shared dict is only mutated, never rebound, so no ``global`` is needed.
        instance_key = id(self)
        snapshot_count[instance_key] += 1
21
+
22
+
23
class MockNode(BaseNode):
    # Minimal node. The tests only use it for its output descriptor ``Outputs.baz``,
    # which serves as a key in ``state.meta.node_outputs``.
    class Outputs(BaseOutputs):
        # Single string-typed output.
        baz: str
26
+
27
+
28
def test_state_snapshot__node_attribute_edit():
    # The counter is module-global and keyed by ``id()``. CPython may hand a
    # garbage-collected state's id to a new object, so reset the counter to keep
    # another test's snapshots from leaking into the assertions below.
    snapshot_count.clear()

    # GIVEN an initial state instance
    state = MockState(foo="bar")
    assert snapshot_count[id(state)] == 0

    # WHEN we edit an attribute
    state.foo = "baz"

    # THEN the snapshot is emitted
    assert snapshot_count[id(state)] == 1
38
+
39
+
40
def test_state_snapshot__node_output_edit():
    # The counter is module-global and keyed by ``id()``. CPython may hand a
    # garbage-collected state's id to a new object, so reset the counter to keep
    # another test's snapshots from leaking into the assertions below.
    snapshot_count.clear()

    # GIVEN an initial state instance
    state = MockState(foo="bar")
    assert snapshot_count[id(state)] == 0

    # WHEN we add a Node Output to state
    for output in MockNode.Outputs:
        state.meta.node_outputs[output] = "hello"

    # THEN the snapshot is emitted
    assert snapshot_count[id(state)] == 1
51
+
52
+
53
def test_state_snapshot__nested_dictionary_edit():
    # The counter is module-global and keyed by ``id()``. CPython may hand a
    # garbage-collected state's id to a new object, so reset the counter to keep
    # another test's snapshots from leaking into the assertions below.
    snapshot_count.clear()

    # GIVEN an initial state instance
    state = MockState(foo="bar")
    assert snapshot_count[id(state)] == 0

    # WHEN we edit a nested dictionary
    state.nested_dict["hello"] = 1

    # THEN the snapshot is emitted
    assert snapshot_count[id(state)] == 1
63
+
64
+
65
def test_state_deepcopy():
    # GIVEN a state holding a single node output
    original = MockState(foo="bar")
    original.meta.node_outputs[MockNode.Outputs.baz] = "hello"

    # WHEN the state is deep-copied
    copied = deepcopy(original)

    # THEN the copy carries node outputs equal to the original's
    assert copied.meta.node_outputs == original.meta.node_outputs
77
+
78
+
79
def test_state_deepcopy__with_node_output_updates():
    # The counter is module-global and keyed by ``id()``. Either state below could
    # reuse the id of an object collected in an earlier test, so reset the counter
    # before checking absolute snapshot counts.
    snapshot_count.clear()

    # GIVEN an initial state instance
    state = MockState(foo="bar")

    # AND we add a Node Output to state
    state.meta.node_outputs[MockNode.Outputs.baz] = "hello"

    # AND we deepcopy the state
    deepcopied_state = deepcopy(state)

    # AND we update the original state
    state.meta.node_outputs[MockNode.Outputs.baz] = "world"

    # THEN the copied state is not updated
    assert deepcopied_state.meta.node_outputs[MockNode.Outputs.baz] == "hello"

    # AND the original state has had the correct number of snapshots
    assert snapshot_count[id(state)] == 2

    # AND the copied state has had the correct number of snapshots
    assert snapshot_count[id(deepcopied_state)] == 0
100
+
101
+
102
def test_state_json_serialization__with_node_output_updates():
    # GIVEN a state with one node output recorded
    state = MockState(foo="bar")
    state.meta.node_outputs[MockNode.Outputs.baz] = "hello"

    # WHEN the state is round-tripped through JSON with the default state encoder
    encoded = json.dumps(state, cls=DefaultStateEncoder)
    decoded = json.loads(encoded)

    # THEN node outputs come back keyed by the output's dotted name
    assert decoded["meta"]["node_outputs"] == {"MockNode.Outputs.baz": "hello"}
File without changes
@@ -0,0 +1,91 @@
1
from enum import Enum
from typing import (  # type: ignore[attr-defined]
    Dict,
    List,
    Union,
    _GenericAlias,
    _SpecialGenericAlias,
    _UnionGenericAlias,
)

from vellum import (
    ChatMessage,
    ChatMessageRequest,
    FunctionCall,
    FunctionCallRequest,
    SearchResult,
    SearchResultRequest,
    VellumAudio,
    VellumAudioRequest,
    VellumError,
    VellumErrorRequest,
    VellumImage,
    VellumImageRequest,
    VellumValue,
    VellumValueRequest,
)
26
+
27
+ JsonArray = List["Json"]
28
+ JsonObject = Dict[str, "Json"]
29
+ Json = Union[None, bool, int, float, str, JsonArray, JsonObject]
30
+
31
+ # Unions and Generics inherit from `_GenericAlias` instead of `type`
32
+ # In future versions of python, we'll see `_UnionGenericAlias`
33
+ UnderGenericAlias = _GenericAlias
34
+ SpecialGenericAlias = _SpecialGenericAlias
35
+ UnionGenericAlias = _UnionGenericAlias
36
+
37
+
38
class VellumSecret:
    """Reference to a Vellum secret by name.

    Only the secret's ``name`` is stored on the instance; no secret value is held here.
    """

    name: str

    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        # The default object repr hides which secret is referenced. Show the name,
        # which is the only thing stored, so logs and debuggers are informative.
        return f"{type(self).__name__}(name={self.name!r})"
43
+
44
+
45
# Union of the Python value types allowed as an entity input value. Members are
# grouped by the kind of input they represent; most groups list both the response
# type and its ``...Request`` counterpart.
VellumValuePrimitive = Union[
    # String inputs
    str,
    # Chat history inputs. This previously listed ``List[ChatMessage]`` twice, a
    # copy-paste slip that left out the request variant.
    List[ChatMessage],
    List[ChatMessageRequest],
    # Search results inputs
    List[SearchResultRequest],
    List[SearchResult],
    # JSON inputs
    Json,
    # Number inputs
    float,
    # Function Call Inputs
    FunctionCall,
    FunctionCallRequest,
    # Error Inputs
    VellumError,
    VellumErrorRequest,
    # Array Inputs
    List[VellumValueRequest],
    List[VellumValue],
    # Image Inputs
    VellumImage,
    VellumImageRequest,
    # Audio Inputs
    VellumAudio,
    VellumAudioRequest,
    # Vellum Secrets
    VellumSecret,
]
76
+
77
# Mapping of input name to its value.
EntityInputsInterface = Dict[str, VellumValuePrimitive]
81
+
82
+
83
class MergeBehavior(Enum):
    # Merge strategy options; each member's value is its own name.
    # NOTE(review): the names suggest "wait for all" vs "wait for any" upstream
    # branches. Confirm against the node scheduling code before relying on it.
    AWAIT_ALL = "AWAIT_ALL"
    AWAIT_ANY = "AWAIT_ANY"
86
+
87
+
88
class ConditionType(Enum):
    # Kinds of conditional branch; each member's value is its own name.
    IF = "IF"
    ELIF = "ELIF"
    ELSE = "ELSE"
@@ -0,0 +1,14 @@
1
+ from typing import TYPE_CHECKING, TypeVar
2
+
3
+ if TYPE_CHECKING:
4
+ from vellum.workflows import BaseWorkflow
5
+ from vellum.workflows.inputs import BaseInputs
6
+ from vellum.workflows.nodes import BaseNode
7
+ from vellum.workflows.outputs import BaseOutputs
8
+ from vellum.workflows.state import BaseState
9
+
10
# Type parameters bounded by the core workflow base classes. The bounds are string
# forward references because those classes are imported only under TYPE_CHECKING.
NodeType = TypeVar("NodeType", bound="BaseNode")
StateType = TypeVar("StateType", bound="BaseState")
WorkflowType = TypeVar("WorkflowType", bound="BaseWorkflow")
WorkflowInputsType = TypeVar("WorkflowInputsType", bound="BaseInputs")
OutputsType = TypeVar("OutputsType", bound="BaseOutputs")
@@ -0,0 +1,39 @@
1
+ from collections import deque
2
+ from typing import Deque, Generic, Iterable, List, TypeVar
3
+
4
_T = TypeVar("_T")


class Stack(Generic[_T]):
    """LIFO stack backed by a ``deque``; the right end of the deque is the top."""

    def __init__(self) -> None:
        self._items: Deque[_T] = deque()

    def push(self, item: _T) -> None:
        """Place ``item`` on top of the stack."""
        self._items.append(item)

    def extend(self, items: Iterable[_T]) -> None:
        """Push all of ``items`` so that the first one ends up on top.

        The iterable is materialized first and then pushed in reverse. Successive
        pops therefore return the elements in their original iteration order.
        """
        self._items.extend(reversed(list(items)))

    def pop(self) -> _T:
        """Remove and return the top item.

        Raises:
            IndexError: if the stack is empty.
        """
        if self.is_empty():
            raise IndexError("pop from empty stack")
        return self._items.pop()

    def peek(self) -> _T:
        """Return the top item without removing it.

        Raises:
            IndexError: if the stack is empty.
        """
        if self.is_empty():
            raise IndexError("peek from empty stack")
        return self._items[-1]

    def is_empty(self) -> bool:
        return not self._items

    def size(self) -> int:
        return len(self._items)

    def __repr__(self) -> str:
        return f"Stack({self.dump()!r})"

    def dump(self) -> List[_T]:
        """Return the items as a new list, top of the stack first."""
        return list(reversed(self._items))
File without changes