vellum-ai 0.9.16rc2__py3-none-any.whl → 0.9.16rc4__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (245) hide show
  1. vellum/plugins/__init__.py +0 -0
  2. vellum/plugins/pydantic.py +74 -0
  3. vellum/plugins/utils.py +19 -0
  4. vellum/plugins/vellum_mypy.py +639 -3
  5. vellum/workflows/README.md +90 -0
  6. vellum/workflows/__init__.py +5 -0
  7. vellum/workflows/constants.py +43 -0
  8. vellum/workflows/descriptors/__init__.py +0 -0
  9. vellum/workflows/descriptors/base.py +339 -0
  10. vellum/workflows/descriptors/tests/test_utils.py +83 -0
  11. vellum/workflows/descriptors/utils.py +90 -0
  12. vellum/workflows/edges/__init__.py +5 -0
  13. vellum/workflows/edges/edge.py +23 -0
  14. vellum/workflows/emitters/__init__.py +5 -0
  15. vellum/workflows/emitters/base.py +14 -0
  16. vellum/workflows/environment/__init__.py +5 -0
  17. vellum/workflows/environment/environment.py +7 -0
  18. vellum/workflows/errors/__init__.py +6 -0
  19. vellum/workflows/errors/types.py +20 -0
  20. vellum/workflows/events/__init__.py +31 -0
  21. vellum/workflows/events/node.py +125 -0
  22. vellum/workflows/events/tests/__init__.py +0 -0
  23. vellum/workflows/events/tests/test_event.py +216 -0
  24. vellum/workflows/events/types.py +52 -0
  25. vellum/workflows/events/utils.py +5 -0
  26. vellum/workflows/events/workflow.py +139 -0
  27. vellum/workflows/exceptions.py +15 -0
  28. vellum/workflows/expressions/__init__.py +0 -0
  29. vellum/workflows/expressions/accessor.py +52 -0
  30. vellum/workflows/expressions/and_.py +32 -0
  31. vellum/workflows/expressions/begins_with.py +31 -0
  32. vellum/workflows/expressions/between.py +38 -0
  33. vellum/workflows/expressions/coalesce_expression.py +41 -0
  34. vellum/workflows/expressions/contains.py +30 -0
  35. vellum/workflows/expressions/does_not_begin_with.py +31 -0
  36. vellum/workflows/expressions/does_not_contain.py +30 -0
  37. vellum/workflows/expressions/does_not_end_with.py +31 -0
  38. vellum/workflows/expressions/does_not_equal.py +25 -0
  39. vellum/workflows/expressions/ends_with.py +31 -0
  40. vellum/workflows/expressions/equals.py +25 -0
  41. vellum/workflows/expressions/greater_than.py +33 -0
  42. vellum/workflows/expressions/greater_than_or_equal_to.py +33 -0
  43. vellum/workflows/expressions/in_.py +31 -0
  44. vellum/workflows/expressions/is_blank.py +24 -0
  45. vellum/workflows/expressions/is_not_blank.py +24 -0
  46. vellum/workflows/expressions/is_not_null.py +21 -0
  47. vellum/workflows/expressions/is_not_undefined.py +22 -0
  48. vellum/workflows/expressions/is_null.py +21 -0
  49. vellum/workflows/expressions/is_undefined.py +22 -0
  50. vellum/workflows/expressions/less_than.py +33 -0
  51. vellum/workflows/expressions/less_than_or_equal_to.py +33 -0
  52. vellum/workflows/expressions/not_between.py +38 -0
  53. vellum/workflows/expressions/not_in.py +31 -0
  54. vellum/workflows/expressions/or_.py +32 -0
  55. vellum/workflows/graph/__init__.py +3 -0
  56. vellum/workflows/graph/graph.py +131 -0
  57. vellum/workflows/graph/tests/__init__.py +0 -0
  58. vellum/workflows/graph/tests/test_graph.py +437 -0
  59. vellum/workflows/inputs/__init__.py +5 -0
  60. vellum/workflows/inputs/base.py +55 -0
  61. vellum/workflows/logging.py +14 -0
  62. vellum/workflows/nodes/__init__.py +46 -0
  63. vellum/workflows/nodes/bases/__init__.py +7 -0
  64. vellum/workflows/nodes/bases/base.py +332 -0
  65. vellum/workflows/nodes/bases/base_subworkflow_node/__init__.py +5 -0
  66. vellum/workflows/nodes/bases/base_subworkflow_node/node.py +10 -0
  67. vellum/workflows/nodes/bases/tests/__init__.py +0 -0
  68. vellum/workflows/nodes/bases/tests/test_base_node.py +125 -0
  69. vellum/workflows/nodes/core/__init__.py +16 -0
  70. vellum/workflows/nodes/core/error_node/__init__.py +5 -0
  71. vellum/workflows/nodes/core/error_node/node.py +26 -0
  72. vellum/workflows/nodes/core/inline_subworkflow_node/__init__.py +5 -0
  73. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +73 -0
  74. vellum/workflows/nodes/core/map_node/__init__.py +5 -0
  75. vellum/workflows/nodes/core/map_node/node.py +147 -0
  76. vellum/workflows/nodes/core/map_node/tests/__init__.py +0 -0
  77. vellum/workflows/nodes/core/map_node/tests/test_node.py +65 -0
  78. vellum/workflows/nodes/core/retry_node/__init__.py +5 -0
  79. vellum/workflows/nodes/core/retry_node/node.py +106 -0
  80. vellum/workflows/nodes/core/retry_node/tests/__init__.py +0 -0
  81. vellum/workflows/nodes/core/retry_node/tests/test_node.py +93 -0
  82. vellum/workflows/nodes/core/templating_node/__init__.py +5 -0
  83. vellum/workflows/nodes/core/templating_node/custom_filters.py +12 -0
  84. vellum/workflows/nodes/core/templating_node/exceptions.py +2 -0
  85. vellum/workflows/nodes/core/templating_node/node.py +123 -0
  86. vellum/workflows/nodes/core/templating_node/render.py +55 -0
  87. vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +21 -0
  88. vellum/workflows/nodes/core/try_node/__init__.py +5 -0
  89. vellum/workflows/nodes/core/try_node/node.py +110 -0
  90. vellum/workflows/nodes/core/try_node/tests/__init__.py +0 -0
  91. vellum/workflows/nodes/core/try_node/tests/test_node.py +82 -0
  92. vellum/workflows/nodes/displayable/__init__.py +31 -0
  93. vellum/workflows/nodes/displayable/api_node/__init__.py +5 -0
  94. vellum/workflows/nodes/displayable/api_node/node.py +44 -0
  95. vellum/workflows/nodes/displayable/bases/__init__.py +11 -0
  96. vellum/workflows/nodes/displayable/bases/api_node/__init__.py +5 -0
  97. vellum/workflows/nodes/displayable/bases/api_node/node.py +70 -0
  98. vellum/workflows/nodes/displayable/bases/base_prompt_node/__init__.py +5 -0
  99. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +60 -0
  100. vellum/workflows/nodes/displayable/bases/inline_prompt_node/__init__.py +5 -0
  101. vellum/workflows/nodes/displayable/bases/inline_prompt_node/constants.py +13 -0
  102. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +118 -0
  103. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +98 -0
  104. vellum/workflows/nodes/displayable/bases/search_node.py +90 -0
  105. vellum/workflows/nodes/displayable/code_execution_node/__init__.py +5 -0
  106. vellum/workflows/nodes/displayable/code_execution_node/node.py +197 -0
  107. vellum/workflows/nodes/displayable/code_execution_node/tests/__init__.py +0 -0
  108. vellum/workflows/nodes/displayable/code_execution_node/tests/fixtures/__init__.py +0 -0
  109. vellum/workflows/nodes/displayable/code_execution_node/tests/fixtures/main.py +3 -0
  110. vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +111 -0
  111. vellum/workflows/nodes/displayable/code_execution_node/utils.py +10 -0
  112. vellum/workflows/nodes/displayable/conditional_node/__init__.py +5 -0
  113. vellum/workflows/nodes/displayable/conditional_node/node.py +25 -0
  114. vellum/workflows/nodes/displayable/final_output_node/__init__.py +5 -0
  115. vellum/workflows/nodes/displayable/final_output_node/node.py +43 -0
  116. vellum/workflows/nodes/displayable/guardrail_node/__init__.py +5 -0
  117. vellum/workflows/nodes/displayable/guardrail_node/node.py +97 -0
  118. vellum/workflows/nodes/displayable/inline_prompt_node/__init__.py +5 -0
  119. vellum/workflows/nodes/displayable/inline_prompt_node/node.py +41 -0
  120. vellum/workflows/nodes/displayable/merge_node/__init__.py +5 -0
  121. vellum/workflows/nodes/displayable/merge_node/node.py +10 -0
  122. vellum/workflows/nodes/displayable/prompt_deployment_node/__init__.py +5 -0
  123. vellum/workflows/nodes/displayable/prompt_deployment_node/node.py +45 -0
  124. vellum/workflows/nodes/displayable/search_node/__init__.py +5 -0
  125. vellum/workflows/nodes/displayable/search_node/node.py +26 -0
  126. vellum/workflows/nodes/displayable/subworkflow_deployment_node/__init__.py +5 -0
  127. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +156 -0
  128. vellum/workflows/nodes/displayable/tests/__init__.py +0 -0
  129. vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py +148 -0
  130. vellum/workflows/nodes/displayable/tests/test_search_node_wth_text_output.py +134 -0
  131. vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +80 -0
  132. vellum/workflows/nodes/utils.py +27 -0
  133. vellum/workflows/outputs/__init__.py +6 -0
  134. vellum/workflows/outputs/base.py +196 -0
  135. vellum/workflows/ports/__init__.py +7 -0
  136. vellum/workflows/ports/node_ports.py +75 -0
  137. vellum/workflows/ports/port.py +75 -0
  138. vellum/workflows/ports/utils.py +40 -0
  139. vellum/workflows/references/__init__.py +17 -0
  140. vellum/workflows/references/environment_variable.py +20 -0
  141. vellum/workflows/references/execution_count.py +20 -0
  142. vellum/workflows/references/external_input.py +49 -0
  143. vellum/workflows/references/input.py +7 -0
  144. vellum/workflows/references/lazy.py +55 -0
  145. vellum/workflows/references/node.py +43 -0
  146. vellum/workflows/references/output.py +78 -0
  147. vellum/workflows/references/state_value.py +23 -0
  148. vellum/workflows/references/vellum_secret.py +15 -0
  149. vellum/workflows/references/workflow_input.py +41 -0
  150. vellum/workflows/resolvers/__init__.py +5 -0
  151. vellum/workflows/resolvers/base.py +15 -0
  152. vellum/workflows/runner/__init__.py +5 -0
  153. vellum/workflows/runner/runner.py +588 -0
  154. vellum/workflows/runner/types.py +18 -0
  155. vellum/workflows/state/__init__.py +5 -0
  156. vellum/workflows/state/base.py +327 -0
  157. vellum/workflows/state/context.py +18 -0
  158. vellum/workflows/state/encoder.py +57 -0
  159. vellum/workflows/state/store.py +28 -0
  160. vellum/workflows/state/tests/__init__.py +0 -0
  161. vellum/workflows/state/tests/test_state.py +113 -0
  162. vellum/workflows/types/__init__.py +0 -0
  163. vellum/workflows/types/core.py +91 -0
  164. vellum/workflows/types/generics.py +14 -0
  165. vellum/workflows/types/stack.py +39 -0
  166. vellum/workflows/types/tests/__init__.py +0 -0
  167. vellum/workflows/types/tests/test_utils.py +76 -0
  168. vellum/workflows/types/utils.py +164 -0
  169. vellum/workflows/utils/__init__.py +0 -0
  170. vellum/workflows/utils/names.py +13 -0
  171. vellum/workflows/utils/tests/__init__.py +0 -0
  172. vellum/workflows/utils/tests/test_names.py +15 -0
  173. vellum/workflows/utils/tests/test_vellum_variables.py +25 -0
  174. vellum/workflows/utils/vellum_variables.py +81 -0
  175. vellum/workflows/vellum_client.py +18 -0
  176. vellum/workflows/workflows/__init__.py +5 -0
  177. vellum/workflows/workflows/base.py +365 -0
  178. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/METADATA +2 -1
  179. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/RECORD +245 -7
  180. vellum_cli/__init__.py +72 -0
  181. vellum_cli/aliased_group.py +103 -0
  182. vellum_cli/config.py +96 -0
  183. vellum_cli/image_push.py +112 -0
  184. vellum_cli/logger.py +36 -0
  185. vellum_cli/pull.py +73 -0
  186. vellum_cli/push.py +121 -0
  187. vellum_cli/tests/test_config.py +100 -0
  188. vellum_cli/tests/test_pull.py +152 -0
  189. vellum_ee/workflows/__init__.py +0 -0
  190. vellum_ee/workflows/display/__init__.py +0 -0
  191. vellum_ee/workflows/display/base.py +73 -0
  192. vellum_ee/workflows/display/nodes/__init__.py +4 -0
  193. vellum_ee/workflows/display/nodes/base_node_display.py +116 -0
  194. vellum_ee/workflows/display/nodes/base_node_vellum_display.py +36 -0
  195. vellum_ee/workflows/display/nodes/get_node_display_class.py +25 -0
  196. vellum_ee/workflows/display/nodes/tests/__init__.py +0 -0
  197. vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +47 -0
  198. vellum_ee/workflows/display/nodes/types.py +18 -0
  199. vellum_ee/workflows/display/nodes/utils.py +33 -0
  200. vellum_ee/workflows/display/nodes/vellum/__init__.py +32 -0
  201. vellum_ee/workflows/display/nodes/vellum/api_node.py +205 -0
  202. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +71 -0
  203. vellum_ee/workflows/display/nodes/vellum/conditional_node.py +217 -0
  204. vellum_ee/workflows/display/nodes/vellum/final_output_node.py +61 -0
  205. vellum_ee/workflows/display/nodes/vellum/guardrail_node.py +49 -0
  206. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +170 -0
  207. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +99 -0
  208. vellum_ee/workflows/display/nodes/vellum/map_node.py +100 -0
  209. vellum_ee/workflows/display/nodes/vellum/merge_node.py +48 -0
  210. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +68 -0
  211. vellum_ee/workflows/display/nodes/vellum/search_node.py +193 -0
  212. vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +58 -0
  213. vellum_ee/workflows/display/nodes/vellum/templating_node.py +67 -0
  214. vellum_ee/workflows/display/nodes/vellum/tests/__init__.py +0 -0
  215. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +106 -0
  216. vellum_ee/workflows/display/nodes/vellum/try_node.py +38 -0
  217. vellum_ee/workflows/display/nodes/vellum/utils.py +76 -0
  218. vellum_ee/workflows/display/tests/__init__.py +0 -0
  219. vellum_ee/workflows/display/tests/workflow_serialization/__init__.py +0 -0
  220. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +426 -0
  221. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +607 -0
  222. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +1175 -0
  223. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +235 -0
  224. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +511 -0
  225. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +372 -0
  226. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +272 -0
  227. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +289 -0
  228. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +354 -0
  229. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +123 -0
  230. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +84 -0
  231. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +233 -0
  232. vellum_ee/workflows/display/types.py +46 -0
  233. vellum_ee/workflows/display/utils/__init__.py +0 -0
  234. vellum_ee/workflows/display/utils/tests/__init__.py +0 -0
  235. vellum_ee/workflows/display/utils/tests/test_uuids.py +16 -0
  236. vellum_ee/workflows/display/utils/uuids.py +24 -0
  237. vellum_ee/workflows/display/utils/vellum.py +121 -0
  238. vellum_ee/workflows/display/vellum.py +357 -0
  239. vellum_ee/workflows/display/workflows/__init__.py +5 -0
  240. vellum_ee/workflows/display/workflows/base_workflow_display.py +302 -0
  241. vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +32 -0
  242. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +386 -0
  243. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/LICENSE +0 -0
  244. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/WHEEL +0 -0
  245. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,327 @@
1
+ from collections import defaultdict
2
+ from copy import deepcopy
3
+ from dataclasses import field
4
+ from datetime import datetime
5
+ from queue import Queue
6
+ from threading import Lock
7
+ from uuid import UUID, uuid4
8
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, Optional, Sequence, Set, Tuple, Type, cast
9
+ from typing_extensions import dataclass_transform
10
+
11
+ from pydantic import GetCoreSchemaHandler, field_serializer
12
+ from pydantic_core import core_schema
13
+
14
+ from vellum.core.pydantic_utilities import UniversalBaseModel
15
+
16
+ from vellum.workflows.constants import UNDEF
17
+ from vellum.workflows.inputs.base import BaseInputs
18
+ from vellum.workflows.references import ExternalInputReference, OutputReference, StateValueReference
19
+ from vellum.workflows.types.generics import StateType
20
+ from vellum.workflows.types.stack import Stack
21
+ from vellum.workflows.types.utils import datetime_now, deepcopy_with_exclusions, get_class_by_qualname, infer_types
22
+
23
+ if TYPE_CHECKING:
24
+ from vellum.workflows.nodes.bases import BaseNode
25
+
26
+
27
class _Snapshottable:
    """
    Mixin for values that report edits through a zero-argument `_snapshot_callback`.
    """

    # Hook invoked after an edit. `_make_snapshottable` assigns it on the instances it creates.
    _snapshot_callback: Callable[[], None]

    def __deepcopy__(self, memo: Any) -> "_Snapshottable":
        """
        Deep copies this value through `deepcopy_with_exclusions`, passing the current callback as the
        exclusion for `_snapshot_callback`.
        """
        # NOTE(review): presumably the excluded attribute is carried over as-is instead of being deep copied;
        # confirm against `deepcopy_with_exclusions`.
        excluded_attributes = {"_snapshot_callback": self._snapshot_callback}
        return deepcopy_with_exclusions(self, memo=memo, exclusions=excluded_attributes)
38
+
39
+
40
@dataclass_transform(kw_only_default=True)
class _BaseStateMeta(type):
    """
    Metaclass that answers public class-level attribute access with a `StateValueReference`.
    """

    def __getattribute__(cls, name: str) -> Any:
        # Underscore-prefixed names (private and dunder attributes) resolve the normal way.
        if name.startswith("_"):
            return super().__getattribute__(name)

        # Only the class's own namespace is consulted for the declared value; base classes are not searched here.
        declared_value = vars(cls).get(name)
        inferred_types = infer_types(cls, name)
        return StateValueReference(name=name, types=inferred_types, instance=declared_value)
49
+
50
+
51
class _SnapshottableDict(dict, _Snapshottable):
    """
    A `dict` that invokes its `_snapshot_callback` after every item assignment.
    """

    def __setitem__(self, key: Any, value: Any) -> None:
        # Store the item first so the callback observes the updated mapping.
        dict.__setitem__(self, key, value)

        # NOTE(review): only item assignment is intercepted; `update`, `pop`, `del` and friends do not notify.
        notify = self._snapshot_callback
        notify()
55
+
56
+
57
def _make_snapshottable(value: Any, snapshot_callback: Callable[[], None]) -> Any:
    """
    Wraps `value` so that editing it triggers `snapshot_callback`. Kept at module level rather than on
    `BaseState` so that it cannot clash with attribute names declared by subclasses.

    Plain dicts are copied into a `_SnapshottableDict`. Values that are already snapshottable, and all
    non-dict values, are returned unchanged.
    """
    needs_wrapping = isinstance(value, dict) and not isinstance(value, _Snapshottable)
    if not needs_wrapping:
        return value

    wrapped = _SnapshottableDict(value)
    wrapped._snapshot_callback = snapshot_callback
    return wrapped
71
+
72
+
73
class NodeExecutionCache:
    """
    Records, per node class, which executions have been initiated and which have been fulfilled, plus a
    string-keyed map of invoked dependencies.
    """

    # Execution ids pushed when an execution is fulfilled, keyed by node class.
    _node_execution_ids: Dict[Type["BaseNode"], Stack[UUID]]
    # Execution ids that were initiated and not yet fulfilled, keyed by node class.
    _node_executions_initiated: Dict[Type["BaseNode"], Set[UUID]]
    # Sets of invoked dependency names, keyed by string.
    _dependencies_invoked: Dict[str, Set[str]]

    def __init__(
        self,
        dependencies_invoked: Optional[Dict[str, Sequence[str]]] = None,
        node_execution_ids: Optional[Dict[str, Sequence[str]]] = None,
        node_executions_initiated: Optional[Dict[str, Sequence[str]]] = None,
    ) -> None:
        """
        Builds an empty cache, optionally seeded from serialized data.

        Keys of `node_execution_ids` and `node_executions_initiated` are resolved to classes with
        `get_class_by_qualname`, and their values are parsed into `UUID`s.
        """
        self._dependencies_invoked = defaultdict(set)
        self._node_execution_ids = defaultdict(Stack[UUID])
        self._node_executions_initiated = defaultdict(set)

        for node_key, invoked in (dependencies_invoked or {}).items():
            self._dependencies_invoked[node_key].update(invoked)

        for qualname, raw_ids in (node_execution_ids or {}).items():
            fulfilled = self._node_execution_ids[get_class_by_qualname(qualname)]
            fulfilled.extend(UUID(raw_id) for raw_id in raw_ids)

        for qualname, raw_ids in (node_executions_initiated or {}).items():
            initiated = self._node_executions_initiated[get_class_by_qualname(qualname)]
            initiated.update({UUID(raw_id) for raw_id in raw_ids})

    @property
    def dependencies_invoked(self) -> Dict[str, Set[str]]:
        """The live (not copied) mapping of invoked dependencies."""
        return self._dependencies_invoked

    def is_node_initiated(self, node: Type["BaseNode"]) -> bool:
        """Returns True when `node` has at least one initiated, unfulfilled execution."""
        # `.get` avoids creating an empty entry in the defaultdict for nodes that were never seen.
        return bool(self._node_executions_initiated.get(node))

    def initiate_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
        """Marks `execution_id` as initiated for `node`."""
        self._node_executions_initiated[node].add(execution_id)

    def fulfill_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
        """
        Moves `execution_id` from the initiated set to the stack of fulfilled executions for `node`.

        Raises `KeyError` if the execution was not previously initiated for `node`.
        """
        self._node_executions_initiated[node].remove(execution_id)
        self._node_execution_ids[node].push(execution_id)

    def get_execution_count(self, node: Type["BaseNode"]) -> int:
        """Returns how many executions of `node` have been fulfilled."""
        return self._node_execution_ids[node].size()

    def dump(self) -> Dict[str, Any]:
        """
        Returns a dictionary view of the cache, with node classes rendered through `str()`.
        """
        # NOTE(review): keys are produced with `str(node)` here while `__init__` resolves keys with
        # `get_class_by_qualname`; confirm that the two formats round-trip.
        invoked = {node: list(dependencies) for node, dependencies in self._dependencies_invoked.items()}
        initiated = {
            str(node): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
        }
        fulfilled = {str(node): execution_ids.dump() for node, execution_ids in self._node_execution_ids.items()}
        return {
            "dependencies_invoked": invoked,
            "node_executions_initiated": initiated,
            "node_execution_ids": fulfilled,
        }

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Type[Any], handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """Lets pydantic models hold this class as a field, validated by an instance check only."""
        return core_schema.is_instance_schema(cls)
134
+
135
+
136
def uuid4_default_factory() -> UUID:
    """
    Returns a freshly generated random UUID. Exists as a named module-level function so that tests can
    mock it.
    """
    generated_id = uuid4()
    return generated_id
141
+
142
+
143
def default_datetime_factory() -> datetime:
    """
    Returns the current timestamp as produced by `datetime_now`. Exists as a named module-level function
    so that tests can mock it.
    """
    current_timestamp = datetime_now()
    return current_timestamp
149
+
150
+
151
class StateMeta(UniversalBaseModel):
    """
    Bookkeeping metadata attached to a state instance: identifiers, inputs, node outputs and the node
    execution cache.

    Assigning any field other than `updated_ts` (or a dunder attribute) invokes the registered snapshot
    callback, if one has been added through `add_snapshot_callback`.
    """

    id: UUID = field(default_factory=uuid4_default_factory)
    # Overwritten with the parent's trace id in `model_post_init` when `parent` is set.
    trace_id: UUID = field(default_factory=uuid4_default_factory)
    span_id: UUID = field(default_factory=uuid4_default_factory)
    updated_ts: datetime = field(default_factory=default_datetime_factory)
    workflow_inputs: BaseInputs = field(default_factory=BaseInputs)
    external_inputs: Dict[ExternalInputReference, Any] = field(default_factory=dict)
    node_outputs: Dict[OutputReference, Any] = field(default_factory=dict)
    node_execution_cache: NodeExecutionCache = field(default_factory=NodeExecutionCache)
    parent: Optional["BaseState"] = None
    is_terminated: Optional[bool] = None
    __snapshot_callback__: Optional[Callable[[], None]] = field(init=False, default=None)

    def model_post_init(self, context: Any) -> None:
        """Inherits the parent's trace id, if there is a parent, and starts with no snapshot callback."""
        if self.parent:
            self.trace_id = self.parent.meta.trace_id
        self.__snapshot_callback__ = None

    def add_snapshot_callback(self, callback: Callable[[], None]) -> None:
        """
        Registers `callback` to run after field assignments and after item assignments on `node_outputs`.
        """
        self.node_outputs = _make_snapshottable(self.node_outputs, callback)
        self.__snapshot_callback__ = callback

    def __setattr__(self, name: str, value: Any) -> None:
        # Dunder attributes and `updated_ts` are written silently, without triggering a snapshot.
        if name.startswith("__") or name == "updated_ts":
            super().__setattr__(name, value)
            return

        super().__setattr__(name, value)
        if callable(self.__snapshot_callback__):
            self.__snapshot_callback__()

    @field_serializer("node_outputs")
    def serialize_node_outputs(self, node_outputs: Dict[OutputReference, Any], _info: Any) -> Dict[str, Any]:
        """Serializes `node_outputs` with each descriptor key rendered through `str()`."""
        return {str(descriptor): value for descriptor, value in node_outputs.items()}

    @field_serializer("external_inputs")
    def serialize_external_inputs(
        self, external_inputs: Dict[ExternalInputReference, Any], _info: Any
    ) -> Dict[str, Any]:
        """Serializes `external_inputs` with each descriptor key rendered through `str()`."""
        return {str(descriptor): value for descriptor, value in external_inputs.items()}

    def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "StateMeta":
        """
        Deep copies the metadata. `Queue` values in `node_outputs` are shared with the copy rather than
        copied, and the snapshot callback is mapped to `None` in the memo.
        """
        # Only create a memo when none was supplied. An empty dict passed in by `copy.deepcopy` must be
        # reused so that the caller and this method keep sharing the same memo.
        if memo is None:
            memo = {}

        new_node_outputs = {
            descriptor: value if isinstance(value, Queue) else deepcopy(value, memo)
            for descriptor, value in self.node_outputs.items()
        }

        # Pre-seed the memo so the parent implementation reuses these results instead of copying again.
        memo[id(self.node_outputs)] = new_node_outputs
        memo[id(self.__snapshot_callback__)] = None

        return super().__deepcopy__(memo)
205
+
206
+
207
class BaseState(metaclass=_BaseStateMeta):
    """
    Base class for workflow state. Assigning a public attribute refreshes `meta.updated_ts` and emits a
    snapshot through `__snapshot__`.
    """

    meta: StateMeta = field(init=False)

    __lock__: Lock = field(init=False)
    __is_initializing__: bool = field(init=False)
    __snapshot_callback__: Callable[["BaseState"], None] = field(init=False)

    def __init__(self, meta: Optional[StateMeta] = None, **kwargs: Any) -> None:
        """
        Creates the state with the given `meta` (or a fresh `StateMeta`) and assigns every keyword argument
        as an attribute. No snapshot is emitted by these initial assignments.
        """
        self.__is_initializing__ = True
        self.__snapshot_callback__ = lambda state: None
        self.__lock__ = Lock()

        self.meta = meta or StateMeta()
        self.meta.add_snapshot_callback(self.__snapshot__)

        # Make all class attribute values snapshottable
        for name, value in self.__class__.__dict__.items():
            if not name.startswith("_") and name != "meta":
                # Bypass __is_initializing__ instead of `setattr`
                snapshottable_value = _make_snapshottable(value, self.__snapshot__)
                super().__setattr__(name, snapshottable_value)

        for name, value in kwargs.items():
            setattr(self, name, value)

        self.__is_initializing__ = False

    def __deepcopy__(self, memo: Any) -> "BaseState":
        """
        Deep copies the state, giving the copy its own `Lock` and pointing the copied `meta` at the copy's
        own `__snapshot__`.
        """
        new_state = deepcopy_with_exclusions(
            self,
            exclusions={
                "__lock__": Lock(),
            },
            memo=memo,
        )
        new_state.meta.add_snapshot_callback(new_state.__snapshot__)
        return new_state

    def __repr__(self) -> str:
        values = "\n".join(
            [f" {key}={value}" for key, value in vars(self).items() if not key.startswith("_") and key != "meta"]
        )
        node_outputs = "\n".join([f" {key}={value}" for key, value in self.meta.node_outputs.items()])
        # NOTE(review): the whitespace inside these string literals is reproduced exactly as it appears in the
        # diff rendering, which collapses indentation; confirm against the packaged file.
        return f"""\
{self.__class__.__name__}:
{values}
meta:
id={self.meta.id}
is_terminated={self.meta.is_terminated}
updated_ts={self.meta.updated_ts}
node_outputs:{' Empty' if not node_outputs else ''}
{node_outputs}
"""

    def __iter__(self) -> Iterator[Tuple[Any, Any]]:
        """
        Returns an iterator treating all state keys as (key, value) items, allowing consumers to call `dict()`
        on an instance of this class.
        """

        # If the user sets a default value on state (e.g. something = "foo"), it's not on `instance_attributes` below.
        # So we need to include class_attributes here just in case
        class_attributes = {key: value for key, value in self.__class__.__dict__.items() if not key.startswith("_")}
        instance_attributes = {key: value for key, value in self.__dict__.items() if not key.startswith("__")}

        # Instance values take precedence over class-level defaults of the same name.
        all_attributes = {**class_attributes, **instance_attributes}
        # NOTE(review): the lock is stored as `__lock__`, which the dunder filter above already drops; `_lock`
        # looks like a leftover name. Kept for backward compatibility.
        items = [(key, value) for key, value in all_attributes.items() if key not in ["_lock"]]
        return iter(items)

    def __getitem__(self, key: str) -> Any:
        """Returns the instance attribute named `key`; raises `KeyError` when it is not set on the instance."""
        return self.__dict__[key]

    def __setattr__(self, name: str, value: Any) -> None:
        # Private attributes, and everything assigned while `__init__` is running, are stored silently.
        if name.startswith("_") or self.__is_initializing__:
            super().__setattr__(name, value)
            return

        snapshottable_value = _make_snapshottable(value, self.__snapshot__)
        super().__setattr__(name, snapshottable_value)
        self.meta.updated_ts = datetime_now()
        self.__snapshot__()

    def __add__(self, other: StateType) -> StateType:
        """
        Handles merging two states together, preferring the latest state by updated_ts for any given node output.

        The latest state is mutated in place and returned. Node outputs and public attributes from the older
        state are only copied over when the latest state does not already have them.

        Raises `TypeError` when `other` is not an instance of this state's class.
        """

        if not isinstance(other, type(self)):
            raise TypeError(f"Cannot add {type(other).__name__} to {type(self).__name__}")

        latest_state = self if self.meta.updated_ts >= other.meta.updated_ts else other
        # Identity, not equality, decides which operand is the older one; a subclass with a value-based
        # `__eq__` must not be able to make both names point at the same object.
        oldest_state = other if latest_state is self else self

        for descriptor, value in oldest_state.meta.node_outputs.items():
            if descriptor not in latest_state.meta.node_outputs:
                latest_state.meta.node_outputs[descriptor] = value

        for key, value in oldest_state:
            if not isinstance(key, str):
                continue

            if key.startswith("_"):
                continue

            if getattr(latest_state, key, UNDEF) == UNDEF:
                setattr(latest_state, key, value)

        return cast(StateType, latest_state)

    def __snapshot__(self) -> None:
        """
        Snapshots the current state to the workflow emitter. The invoked callback is overridden by the
        workflow runner.
        """
        self.__snapshot_callback__(deepcopy(self))

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Type[Any], handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """Lets pydantic models hold state instances as fields, validated by an instance check only."""
        return core_schema.is_instance_schema(cls)
@@ -0,0 +1,18 @@
1
+ from functools import cached_property
2
+ from typing import Optional
3
+
4
+ from vellum import Vellum
5
+
6
+ from vellum.workflows.vellum_client import create_vellum_client
7
+
8
+
9
class WorkflowContext:
    """
    Carries the Vellum client used by a workflow, resolving it lazily on first access.
    """

    def __init__(self, _vellum_client: Optional[Vellum] = None):
        self._vellum_client = _vellum_client

    @cached_property
    def vellum_client(self) -> Vellum:
        """
        The client passed to the constructor when it is truthy, otherwise one built by
        `create_vellum_client`. The result is cached on the instance after the first access.
        """
        return self._vellum_client or create_vellum_client()
@@ -0,0 +1,57 @@
1
+ from dataclasses import asdict, is_dataclass
2
+ from datetime import datetime
3
+ import enum
4
+ from json import JSONEncoder
5
+ from uuid import UUID
6
+ from typing import Any, Callable, Dict, Type
7
+
8
+ from pydantic import BaseModel
9
+
10
+ from vellum.workflows.inputs.base import BaseInputs
11
+ from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
12
+ from vellum.workflows.state.base import BaseState, NodeExecutionCache
13
+
14
+
15
class DefaultStateEncoder(JSONEncoder):
    """
    JSON encoder for workflow state and the value types commonly stored on it.

    Objects not covered by the built-in branches can be supported by registering a callable in `encoders`.
    """

    # Extra encoders keyed by class. The lookup uses the object's exact class, so subclasses of a registered
    # class do not match.
    encoders: Dict[Type, Callable] = {}

    def default(self, obj: Any) -> Any:
        """
        Returns a JSON-serializable representation of `obj`.

        Falls back to `JSONEncoder.default`, which raises `TypeError`, when no branch and no registered
        encoder handles the object.
        """
        if isinstance(obj, BaseState):
            return dict(obj)

        if isinstance(obj, (BaseInputs, BaseOutputs)):
            return {descriptor.name: value for descriptor, value in obj}

        if isinstance(obj, BaseOutput):
            return obj.serialize()

        if isinstance(obj, NodeExecutionCache):
            return obj.dump()

        if isinstance(obj, UUID):
            return str(obj)

        if isinstance(obj, set):
            return list(obj)

        if isinstance(obj, BaseModel):
            return obj.model_dump()

        if isinstance(obj, datetime):
            return obj.isoformat()

        if isinstance(obj, enum.Enum):
            return obj.value

        # `is_dataclass` is also true for dataclass *classes*, which `asdict` rejects with a TypeError.
        # Restrict this branch to instances so that classes reach the `type` branch below.
        if is_dataclass(obj) and not isinstance(obj, type):
            return asdict(obj)  # type: ignore[call-overload]

        if isinstance(obj, type):
            return str(obj)

        if obj.__class__ in self.encoders:
            return self.encoders[obj.__class__](obj)

        return super().default(obj)
@@ -0,0 +1,28 @@
1
+ from typing import Iterator, List
2
+
3
+ from vellum.workflows.events.workflow import WorkflowEvent
4
+ from vellum.workflows.state.base import BaseState
5
+
6
+
7
class Store:
    """
    In-memory record of the workflow events and state snapshots appended to it.
    """

    def __init__(self) -> None:
        self._state_snapshots: List[BaseState] = []
        self._events: List[WorkflowEvent] = []

    def append_event(self, event: WorkflowEvent) -> None:
        """Adds `event` to the end of the recorded events."""
        self._events.append(event)

    def append_state_snapshot(self, state: BaseState) -> None:
        """Adds `state` to the end of the recorded snapshots."""
        self._state_snapshots.append(state)

    def clear(self) -> None:
        """Discards all recorded events and snapshots."""
        # Rebind to new lists rather than emptying in place, so iterators handed out earlier keep the
        # lists they were created from.
        self._events, self._state_snapshots = [], []

    @property
    def events(self) -> Iterator[WorkflowEvent]:
        """A fresh iterator over the recorded events, in insertion order."""
        recorded_events = self._events
        return iter(recorded_events)

    @property
    def state_snapshots(self) -> Iterator[BaseState]:
        """A fresh iterator over the recorded state snapshots, in insertion order."""
        recorded_snapshots = self._state_snapshots
        return iter(recorded_snapshots)
File without changes
@@ -0,0 +1,113 @@
1
+ from collections import defaultdict
2
+ from copy import deepcopy
3
+ import json
4
+ from typing import Dict
5
+
6
+ from vellum.workflows.nodes.bases import BaseNode
7
+ from vellum.workflows.outputs.base import BaseOutputs
8
+ from vellum.workflows.state.base import BaseState
9
+ from vellum.workflows.state.encoder import DefaultStateEncoder
10
+
11
# Number of `__snapshot__` calls observed per state instance, keyed by `id()` of the instance.
snapshot_count: Dict[int, int] = defaultdict(int)


class MockState(BaseState):
    # State fixture whose `__snapshot__` hook only counts its own invocations.
    foo: str
    nested_dict: Dict[str, int] = {}

    def __snapshot__(self) -> None:
        # The module-level counter is mutated, never rebound, so no `global` declaration is required.
        instance_key = id(self)
        snapshot_count[instance_key] += 1
21
+
22
+
23
class MockNode(BaseNode):
    # Node fixture exposing a single string output, `baz`, used as a key into state node outputs.
    class Outputs(BaseOutputs):
        baz: str
26
+
27
+
28
def test_state_snapshot__node_attribute_edit():
    # GIVEN a freshly built state that has not been snapshotted yet
    mock_state = MockState(foo="bar")
    state_key = id(mock_state)
    assert snapshot_count[state_key] == 0

    # WHEN one of its attributes is reassigned
    mock_state.foo = "baz"

    # THEN exactly one snapshot has been recorded for it
    assert snapshot_count[state_key] == 1
38
+
39
+
40
def test_state_snapshot__node_output_edit():
    # GIVEN a freshly built state that has not been snapshotted yet
    mock_state = MockState(foo="bar")
    state_key = id(mock_state)
    assert snapshot_count[state_key] == 0

    # WHEN every output declared on the mock node is written into the state's node outputs
    for node_output in MockNode.Outputs:
        mock_state.meta.node_outputs[node_output] = "hello"

    # THEN exactly one snapshot has been recorded for it
    assert snapshot_count[state_key] == 1
51
+
52
+
53
def test_state_snapshot__nested_dictionary_edit():
    # GIVEN a freshly built state that has not been snapshotted yet
    mock_state = MockState(foo="bar")
    state_key = id(mock_state)
    assert snapshot_count[state_key] == 0

    # WHEN a key is set inside the dictionary held by the state
    mock_state.nested_dict["hello"] = 1

    # THEN exactly one snapshot has been recorded for it
    assert snapshot_count[state_key] == 1
63
+
64
+
65
def test_state_deepcopy():
    # GIVEN a state holding one node output
    original_state = MockState(foo="bar")
    original_state.meta.node_outputs[MockNode.Outputs.baz] = "hello"

    # WHEN the state is deep-copied
    copied_state = deepcopy(original_state)

    # THEN the copy's node outputs compare equal to the original's
    assert copied_state.meta.node_outputs == original_state.meta.node_outputs
77
+
78
+
79
def test_state_deepcopy__with_node_output_updates():
    # GIVEN a state holding one node output
    original_state = MockState(foo="bar")
    original_state.meta.node_outputs[MockNode.Outputs.baz] = "hello"

    # AND a deep copy taken before any further change
    copied_state = deepcopy(original_state)

    # WHEN the original's node output is overwritten
    original_state.meta.node_outputs[MockNode.Outputs.baz] = "world"

    # THEN the copy still holds the value it was copied with
    assert copied_state.meta.node_outputs[MockNode.Outputs.baz] == "hello"

    # AND the original was snapshotted once per node-output write
    assert snapshot_count[id(original_state)] == 2

    # AND the copy was never snapshotted
    assert snapshot_count[id(copied_state)] == 0
100
+
101
+
102
def test_state_json_serialization__with_node_output_updates():
    # GIVEN a state holding one node output
    mock_state = MockState(foo="bar")
    mock_state.meta.node_outputs[MockNode.Outputs.baz] = "hello"

    # WHEN the state is encoded with the default state encoder and parsed back
    encoded_state = json.dumps(mock_state, cls=DefaultStateEncoder)
    decoded_state = json.loads(encoded_state)

    # THEN the node output is keyed by its dotted descriptor name
    expected_node_outputs = {"MockNode.Outputs.baz": "hello"}
    assert decoded_state["meta"]["node_outputs"] == expected_node_outputs
File without changes
@@ -0,0 +1,91 @@
1
+ from enum import Enum
2
+ from typing import ( # type: ignore[attr-defined]
3
+ Dict,
4
+ List,
5
+ Union,
6
+ _GenericAlias,
7
+ _SpecialGenericAlias,
8
+ _UnionGenericAlias,
9
+ )
10
+
11
from vellum import (
    ChatMessage,
    ChatMessageRequest,
    FunctionCall,
    FunctionCallRequest,
    SearchResult,
    SearchResultRequest,
    VellumAudio,
    VellumAudioRequest,
    VellumError,
    VellumErrorRequest,
    VellumImage,
    VellumImageRequest,
    VellumValue,
    VellumValueRequest,
)
26
+
27
# Recursive JSON value aliases; "Json" is a forward reference because the three names refer to each other.
JsonArray = List["Json"]
JsonObject = Dict[str, "Json"]
Json = Union[
    None,
    bool,
    int,
    float,
    str,
    JsonArray,
    JsonObject,
]
30
+
31
# Public names for the private `typing` alias classes.
# Unions and generics are instances of `_GenericAlias` rather than of `type`;
# in future versions of Python, unions show up as `_UnionGenericAlias`.
UnderGenericAlias = _GenericAlias
SpecialGenericAlias = _SpecialGenericAlias
UnionGenericAlias = _UnionGenericAlias
36
+
37
+
38
class VellumSecret:
    """Reference to a secret, identified only by its name."""

    # Name under which the secret is looked up.
    name: str

    def __init__(self, name: str):
        """Store ``name`` on the instance unchanged."""
        self.name = name
43
+
44
+
45
# Every Python value accepted as a Vellum input. Each category lists the
# request-side type next to its response-side counterpart.
VellumValuePrimitive = Union[
    # String inputs
    str,
    # Chat history inputs
    # (the request variant was previously missing: `List[ChatMessage]` was listed twice)
    List[ChatMessageRequest],
    List[ChatMessage],
    # Search results inputs
    List[SearchResultRequest],
    List[SearchResult],
    # JSON inputs
    Json,
    # Number inputs
    float,
    # Function Call Inputs
    FunctionCall,
    FunctionCallRequest,
    # Error Inputs
    VellumError,
    VellumErrorRequest,
    # Array Inputs
    List[VellumValueRequest],
    List[VellumValue],
    # Image Inputs
    VellumImage,
    VellumImageRequest,
    # Audio Inputs
    VellumAudio,
    VellumAudioRequest,
    # Vellum Secrets
    VellumSecret,
]
76
+
77
# Mapping from an input's name to the primitive value supplied for it.
EntityInputsInterface = Dict[str, VellumValuePrimitive]
81
+
82
+
83
class MergeBehavior(Enum):
    """Available merge behaviors; each member's value is its own name."""

    AWAIT_ALL = "AWAIT_ALL"
    AWAIT_ANY = "AWAIT_ANY"
86
+
87
+
88
class ConditionType(Enum):
    """Kinds of conditional branch; each member's value is its own name."""

    IF = "IF"
    ELIF = "ELIF"
    ELSE = "ELSE"
@@ -0,0 +1,14 @@
1
+ from typing import TYPE_CHECKING, TypeVar
2
+
3
+ if TYPE_CHECKING:
4
+ from vellum.workflows import BaseWorkflow
5
+ from vellum.workflows.inputs import BaseInputs
6
+ from vellum.workflows.nodes import BaseNode
7
+ from vellum.workflows.outputs import BaseOutputs
8
+ from vellum.workflows.state import BaseState
9
+
10
# Type variables for the workflow base classes. Each bound is given as a string
# because the classes themselves are only imported under `TYPE_CHECKING`.
NodeType = TypeVar("NodeType", bound="BaseNode")
StateType = TypeVar("StateType", bound="BaseState")
WorkflowType = TypeVar("WorkflowType", bound="BaseWorkflow")
WorkflowInputsType = TypeVar("WorkflowInputsType", bound="BaseInputs")
OutputsType = TypeVar("OutputsType", bound="BaseOutputs")
@@ -0,0 +1,39 @@
1
+ from collections import deque
2
+ from typing import Deque, Generic, Iterable, List, TypeVar
3
+
4
_T = TypeVar("_T")


class Stack(Generic[_T]):
    """Last-in, first-out container; the right end of the backing deque is the top."""

    def __init__(self) -> None:
        self._items: Deque[_T] = deque()

    def push(self, item: _T) -> None:
        """Place ``item`` on top of the stack."""
        self._items.append(item)

    def extend(self, items: Iterable[_T]) -> None:
        """Push every element of ``items`` so that its first element ends up on top."""
        # Materialize first so the input is fully consumed before the stack is touched.
        pending = list(items)
        self._items.extend(reversed(pending))

    def pop(self) -> _T:
        """Remove and return the top item.

        Raises:
            IndexError: If the stack is empty.
        """
        if self.is_empty():
            raise IndexError("pop from empty stack")
        return self._items.pop()

    def peek(self) -> _T:
        """Return the top item without removing it.

        Raises:
            IndexError: If the stack is empty.
        """
        if self.is_empty():
            raise IndexError("peek from empty stack")
        return self._items[-1]

    def is_empty(self) -> bool:
        """Return True when the stack holds no items."""
        return not self._items

    def size(self) -> int:
        """Return the number of items currently on the stack."""
        return len(self._items)

    def __repr__(self) -> str:
        return "Stack({})".format(self.dump())

    def dump(self) -> List[_T]:
        """Return the items as a new list ordered from top to bottom."""
        return list(reversed(self._items))
File without changes