vellum-ai 0.9.16rc2__py3-none-any.whl → 0.10.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (245)
  1. vellum/plugins/__init__.py +0 -0
  2. vellum/plugins/pydantic.py +74 -0
  3. vellum/plugins/utils.py +19 -0
  4. vellum/plugins/vellum_mypy.py +639 -3
  5. vellum/workflows/README.md +90 -0
  6. vellum/workflows/__init__.py +5 -0
  7. vellum/workflows/constants.py +43 -0
  8. vellum/workflows/descriptors/__init__.py +0 -0
  9. vellum/workflows/descriptors/base.py +339 -0
  10. vellum/workflows/descriptors/tests/test_utils.py +83 -0
  11. vellum/workflows/descriptors/utils.py +90 -0
  12. vellum/workflows/edges/__init__.py +5 -0
  13. vellum/workflows/edges/edge.py +23 -0
  14. vellum/workflows/emitters/__init__.py +5 -0
  15. vellum/workflows/emitters/base.py +14 -0
  16. vellum/workflows/environment/__init__.py +5 -0
  17. vellum/workflows/environment/environment.py +7 -0
  18. vellum/workflows/errors/__init__.py +6 -0
  19. vellum/workflows/errors/types.py +20 -0
  20. vellum/workflows/events/__init__.py +31 -0
  21. vellum/workflows/events/node.py +125 -0
  22. vellum/workflows/events/tests/__init__.py +0 -0
  23. vellum/workflows/events/tests/test_event.py +216 -0
  24. vellum/workflows/events/types.py +52 -0
  25. vellum/workflows/events/utils.py +5 -0
  26. vellum/workflows/events/workflow.py +139 -0
  27. vellum/workflows/exceptions.py +15 -0
  28. vellum/workflows/expressions/__init__.py +0 -0
  29. vellum/workflows/expressions/accessor.py +52 -0
  30. vellum/workflows/expressions/and_.py +32 -0
  31. vellum/workflows/expressions/begins_with.py +31 -0
  32. vellum/workflows/expressions/between.py +38 -0
  33. vellum/workflows/expressions/coalesce_expression.py +41 -0
  34. vellum/workflows/expressions/contains.py +30 -0
  35. vellum/workflows/expressions/does_not_begin_with.py +31 -0
  36. vellum/workflows/expressions/does_not_contain.py +30 -0
  37. vellum/workflows/expressions/does_not_end_with.py +31 -0
  38. vellum/workflows/expressions/does_not_equal.py +25 -0
  39. vellum/workflows/expressions/ends_with.py +31 -0
  40. vellum/workflows/expressions/equals.py +25 -0
  41. vellum/workflows/expressions/greater_than.py +33 -0
  42. vellum/workflows/expressions/greater_than_or_equal_to.py +33 -0
  43. vellum/workflows/expressions/in_.py +31 -0
  44. vellum/workflows/expressions/is_blank.py +24 -0
  45. vellum/workflows/expressions/is_not_blank.py +24 -0
  46. vellum/workflows/expressions/is_not_null.py +21 -0
  47. vellum/workflows/expressions/is_not_undefined.py +22 -0
  48. vellum/workflows/expressions/is_null.py +21 -0
  49. vellum/workflows/expressions/is_undefined.py +22 -0
  50. vellum/workflows/expressions/less_than.py +33 -0
  51. vellum/workflows/expressions/less_than_or_equal_to.py +33 -0
  52. vellum/workflows/expressions/not_between.py +38 -0
  53. vellum/workflows/expressions/not_in.py +31 -0
  54. vellum/workflows/expressions/or_.py +32 -0
  55. vellum/workflows/graph/__init__.py +3 -0
  56. vellum/workflows/graph/graph.py +131 -0
  57. vellum/workflows/graph/tests/__init__.py +0 -0
  58. vellum/workflows/graph/tests/test_graph.py +437 -0
  59. vellum/workflows/inputs/__init__.py +5 -0
  60. vellum/workflows/inputs/base.py +55 -0
  61. vellum/workflows/logging.py +14 -0
  62. vellum/workflows/nodes/__init__.py +46 -0
  63. vellum/workflows/nodes/bases/__init__.py +7 -0
  64. vellum/workflows/nodes/bases/base.py +332 -0
  65. vellum/workflows/nodes/bases/base_subworkflow_node/__init__.py +5 -0
  66. vellum/workflows/nodes/bases/base_subworkflow_node/node.py +10 -0
  67. vellum/workflows/nodes/bases/tests/__init__.py +0 -0
  68. vellum/workflows/nodes/bases/tests/test_base_node.py +125 -0
  69. vellum/workflows/nodes/core/__init__.py +16 -0
  70. vellum/workflows/nodes/core/error_node/__init__.py +5 -0
  71. vellum/workflows/nodes/core/error_node/node.py +26 -0
  72. vellum/workflows/nodes/core/inline_subworkflow_node/__init__.py +5 -0
  73. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +73 -0
  74. vellum/workflows/nodes/core/map_node/__init__.py +5 -0
  75. vellum/workflows/nodes/core/map_node/node.py +147 -0
  76. vellum/workflows/nodes/core/map_node/tests/__init__.py +0 -0
  77. vellum/workflows/nodes/core/map_node/tests/test_node.py +65 -0
  78. vellum/workflows/nodes/core/retry_node/__init__.py +5 -0
  79. vellum/workflows/nodes/core/retry_node/node.py +106 -0
  80. vellum/workflows/nodes/core/retry_node/tests/__init__.py +0 -0
  81. vellum/workflows/nodes/core/retry_node/tests/test_node.py +93 -0
  82. vellum/workflows/nodes/core/templating_node/__init__.py +5 -0
  83. vellum/workflows/nodes/core/templating_node/custom_filters.py +12 -0
  84. vellum/workflows/nodes/core/templating_node/exceptions.py +2 -0
  85. vellum/workflows/nodes/core/templating_node/node.py +123 -0
  86. vellum/workflows/nodes/core/templating_node/render.py +55 -0
  87. vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +21 -0
  88. vellum/workflows/nodes/core/try_node/__init__.py +5 -0
  89. vellum/workflows/nodes/core/try_node/node.py +110 -0
  90. vellum/workflows/nodes/core/try_node/tests/__init__.py +0 -0
  91. vellum/workflows/nodes/core/try_node/tests/test_node.py +82 -0
  92. vellum/workflows/nodes/displayable/__init__.py +31 -0
  93. vellum/workflows/nodes/displayable/api_node/__init__.py +5 -0
  94. vellum/workflows/nodes/displayable/api_node/node.py +44 -0
  95. vellum/workflows/nodes/displayable/bases/__init__.py +11 -0
  96. vellum/workflows/nodes/displayable/bases/api_node/__init__.py +5 -0
  97. vellum/workflows/nodes/displayable/bases/api_node/node.py +70 -0
  98. vellum/workflows/nodes/displayable/bases/base_prompt_node/__init__.py +5 -0
  99. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +60 -0
  100. vellum/workflows/nodes/displayable/bases/inline_prompt_node/__init__.py +5 -0
  101. vellum/workflows/nodes/displayable/bases/inline_prompt_node/constants.py +13 -0
  102. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +118 -0
  103. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +98 -0
  104. vellum/workflows/nodes/displayable/bases/search_node.py +90 -0
  105. vellum/workflows/nodes/displayable/code_execution_node/__init__.py +5 -0
  106. vellum/workflows/nodes/displayable/code_execution_node/node.py +197 -0
  107. vellum/workflows/nodes/displayable/code_execution_node/tests/__init__.py +0 -0
  108. vellum/workflows/nodes/displayable/code_execution_node/tests/fixtures/__init__.py +0 -0
  109. vellum/workflows/nodes/displayable/code_execution_node/tests/fixtures/main.py +3 -0
  110. vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +111 -0
  111. vellum/workflows/nodes/displayable/code_execution_node/utils.py +10 -0
  112. vellum/workflows/nodes/displayable/conditional_node/__init__.py +5 -0
  113. vellum/workflows/nodes/displayable/conditional_node/node.py +25 -0
  114. vellum/workflows/nodes/displayable/final_output_node/__init__.py +5 -0
  115. vellum/workflows/nodes/displayable/final_output_node/node.py +43 -0
  116. vellum/workflows/nodes/displayable/guardrail_node/__init__.py +5 -0
  117. vellum/workflows/nodes/displayable/guardrail_node/node.py +97 -0
  118. vellum/workflows/nodes/displayable/inline_prompt_node/__init__.py +5 -0
  119. vellum/workflows/nodes/displayable/inline_prompt_node/node.py +41 -0
  120. vellum/workflows/nodes/displayable/merge_node/__init__.py +5 -0
  121. vellum/workflows/nodes/displayable/merge_node/node.py +10 -0
  122. vellum/workflows/nodes/displayable/prompt_deployment_node/__init__.py +5 -0
  123. vellum/workflows/nodes/displayable/prompt_deployment_node/node.py +45 -0
  124. vellum/workflows/nodes/displayable/search_node/__init__.py +5 -0
  125. vellum/workflows/nodes/displayable/search_node/node.py +26 -0
  126. vellum/workflows/nodes/displayable/subworkflow_deployment_node/__init__.py +5 -0
  127. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +156 -0
  128. vellum/workflows/nodes/displayable/tests/__init__.py +0 -0
  129. vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py +148 -0
  130. vellum/workflows/nodes/displayable/tests/test_search_node_wth_text_output.py +134 -0
  131. vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +80 -0
  132. vellum/workflows/nodes/utils.py +27 -0
  133. vellum/workflows/outputs/__init__.py +6 -0
  134. vellum/workflows/outputs/base.py +196 -0
  135. vellum/workflows/ports/__init__.py +7 -0
  136. vellum/workflows/ports/node_ports.py +75 -0
  137. vellum/workflows/ports/port.py +75 -0
  138. vellum/workflows/ports/utils.py +40 -0
  139. vellum/workflows/references/__init__.py +17 -0
  140. vellum/workflows/references/environment_variable.py +20 -0
  141. vellum/workflows/references/execution_count.py +20 -0
  142. vellum/workflows/references/external_input.py +49 -0
  143. vellum/workflows/references/input.py +7 -0
  144. vellum/workflows/references/lazy.py +55 -0
  145. vellum/workflows/references/node.py +43 -0
  146. vellum/workflows/references/output.py +78 -0
  147. vellum/workflows/references/state_value.py +23 -0
  148. vellum/workflows/references/vellum_secret.py +15 -0
  149. vellum/workflows/references/workflow_input.py +41 -0
  150. vellum/workflows/resolvers/__init__.py +5 -0
  151. vellum/workflows/resolvers/base.py +15 -0
  152. vellum/workflows/runner/__init__.py +5 -0
  153. vellum/workflows/runner/runner.py +588 -0
  154. vellum/workflows/runner/types.py +18 -0
  155. vellum/workflows/state/__init__.py +5 -0
  156. vellum/workflows/state/base.py +327 -0
  157. vellum/workflows/state/context.py +18 -0
  158. vellum/workflows/state/encoder.py +57 -0
  159. vellum/workflows/state/store.py +28 -0
  160. vellum/workflows/state/tests/__init__.py +0 -0
  161. vellum/workflows/state/tests/test_state.py +113 -0
  162. vellum/workflows/types/__init__.py +0 -0
  163. vellum/workflows/types/core.py +91 -0
  164. vellum/workflows/types/generics.py +14 -0
  165. vellum/workflows/types/stack.py +39 -0
  166. vellum/workflows/types/tests/__init__.py +0 -0
  167. vellum/workflows/types/tests/test_utils.py +76 -0
  168. vellum/workflows/types/utils.py +164 -0
  169. vellum/workflows/utils/__init__.py +0 -0
  170. vellum/workflows/utils/names.py +13 -0
  171. vellum/workflows/utils/tests/__init__.py +0 -0
  172. vellum/workflows/utils/tests/test_names.py +15 -0
  173. vellum/workflows/utils/tests/test_vellum_variables.py +25 -0
  174. vellum/workflows/utils/vellum_variables.py +81 -0
  175. vellum/workflows/vellum_client.py +18 -0
  176. vellum/workflows/workflows/__init__.py +5 -0
  177. vellum/workflows/workflows/base.py +365 -0
  178. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.10.0.dist-info}/METADATA +2 -1
  179. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.10.0.dist-info}/RECORD +245 -7
  180. vellum_cli/__init__.py +72 -0
  181. vellum_cli/aliased_group.py +103 -0
  182. vellum_cli/config.py +96 -0
  183. vellum_cli/image_push.py +112 -0
  184. vellum_cli/logger.py +36 -0
  185. vellum_cli/pull.py +73 -0
  186. vellum_cli/push.py +121 -0
  187. vellum_cli/tests/test_config.py +100 -0
  188. vellum_cli/tests/test_pull.py +152 -0
  189. vellum_ee/workflows/__init__.py +0 -0
  190. vellum_ee/workflows/display/__init__.py +0 -0
  191. vellum_ee/workflows/display/base.py +73 -0
  192. vellum_ee/workflows/display/nodes/__init__.py +4 -0
  193. vellum_ee/workflows/display/nodes/base_node_display.py +116 -0
  194. vellum_ee/workflows/display/nodes/base_node_vellum_display.py +36 -0
  195. vellum_ee/workflows/display/nodes/get_node_display_class.py +25 -0
  196. vellum_ee/workflows/display/nodes/tests/__init__.py +0 -0
  197. vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +47 -0
  198. vellum_ee/workflows/display/nodes/types.py +18 -0
  199. vellum_ee/workflows/display/nodes/utils.py +33 -0
  200. vellum_ee/workflows/display/nodes/vellum/__init__.py +32 -0
  201. vellum_ee/workflows/display/nodes/vellum/api_node.py +205 -0
  202. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +71 -0
  203. vellum_ee/workflows/display/nodes/vellum/conditional_node.py +217 -0
  204. vellum_ee/workflows/display/nodes/vellum/final_output_node.py +61 -0
  205. vellum_ee/workflows/display/nodes/vellum/guardrail_node.py +49 -0
  206. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +170 -0
  207. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +99 -0
  208. vellum_ee/workflows/display/nodes/vellum/map_node.py +100 -0
  209. vellum_ee/workflows/display/nodes/vellum/merge_node.py +48 -0
  210. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +68 -0
  211. vellum_ee/workflows/display/nodes/vellum/search_node.py +193 -0
  212. vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +58 -0
  213. vellum_ee/workflows/display/nodes/vellum/templating_node.py +67 -0
  214. vellum_ee/workflows/display/nodes/vellum/tests/__init__.py +0 -0
  215. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +106 -0
  216. vellum_ee/workflows/display/nodes/vellum/try_node.py +38 -0
  217. vellum_ee/workflows/display/nodes/vellum/utils.py +76 -0
  218. vellum_ee/workflows/display/tests/__init__.py +0 -0
  219. vellum_ee/workflows/display/tests/workflow_serialization/__init__.py +0 -0
  220. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +426 -0
  221. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +607 -0
  222. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +1175 -0
  223. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +235 -0
  224. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +511 -0
  225. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +372 -0
  226. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +272 -0
  227. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +289 -0
  228. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +354 -0
  229. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +123 -0
  230. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +84 -0
  231. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +233 -0
  232. vellum_ee/workflows/display/types.py +46 -0
  233. vellum_ee/workflows/display/utils/__init__.py +0 -0
  234. vellum_ee/workflows/display/utils/tests/__init__.py +0 -0
  235. vellum_ee/workflows/display/utils/tests/test_uuids.py +16 -0
  236. vellum_ee/workflows/display/utils/uuids.py +24 -0
  237. vellum_ee/workflows/display/utils/vellum.py +121 -0
  238. vellum_ee/workflows/display/vellum.py +357 -0
  239. vellum_ee/workflows/display/workflows/__init__.py +5 -0
  240. vellum_ee/workflows/display/workflows/base_workflow_display.py +302 -0
  241. vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +32 -0
  242. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +386 -0
  243. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.10.0.dist-info}/LICENSE +0 -0
  244. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.10.0.dist-info}/WHEEL +0 -0
  245. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.10.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,147 @@
1
+ from collections import defaultdict
2
+ from queue import Empty, Queue
3
+ from threading import Thread
4
+ from typing import TYPE_CHECKING, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, overload
5
+
6
+ from vellum.workflows.descriptors.base import BaseDescriptor
7
+ from vellum.workflows.errors.types import VellumErrorCode
8
+ from vellum.workflows.exceptions import NodeException
9
+ from vellum.workflows.inputs.base import BaseInputs
10
+ from vellum.workflows.nodes.bases import BaseNode
11
+ from vellum.workflows.outputs import BaseOutputs
12
+ from vellum.workflows.state.base import BaseState
13
+ from vellum.workflows.types.generics import NodeType, StateType
14
+
15
+ if TYPE_CHECKING:
16
+ from vellum.workflows import BaseWorkflow
17
+ from vellum.workflows.events.workflow import WorkflowEvent
18
+
19
MapNodeItemType = TypeVar("MapNodeItemType")


class MapNode(BaseNode, Generic[StateType, MapNodeItemType]):
    """
    Used to map over a list of items and execute a Subworkflow on each iteration.

    items: List[MapNodeItemType] - The items to map over
    subworkflow: Type["BaseWorkflow[SubworkflowInputs, BaseState]"] - The Subworkflow to execute on each iteration
    concurrency: Optional[int] = None - The maximum number of concurrent subworkflow executions
        (None runs every iteration at once)
    """

    items: List[MapNodeItemType]
    subworkflow: Type["BaseWorkflow"]
    concurrency: Optional[int] = None

    class Outputs(BaseOutputs):
        mapped_items: list

    class SubworkflowInputs(BaseInputs):
        # TODO: Both type: ignore's below are believed to be incorrect and both have the following error:
        # Type variable "workflows.nodes.map_node.map_node.MapNodeItemType" is unbound
        # https://app.shortcut.com/vellum/story/4118

        item: MapNodeItemType  # type: ignore[valid-type]
        index: int
        all_items: List[MapNodeItemType]  # type: ignore[valid-type]

    def run(self) -> Outputs:
        """
        Execute ``subworkflow`` once per entry in ``items`` on worker threads and gather the results.

        Every output declared on ``subworkflow.Outputs`` becomes a list aligned with ``items``:
        position ``i`` holds the value produced by iteration ``i``. At most ``concurrency``
        iterations run at the same time; ``None`` starts all of them immediately.

        Raises:
            ValueError: if ``concurrency`` is set to a value below 1.
            NodeException: if any iteration pauses or is rejected.
        """
        if self.concurrency is not None and self.concurrency < 1:
            raise ValueError("concurrency must be greater than 0")

        items = self.items
        item_count = len(items)

        mapped_items: Dict[str, List] = defaultdict(list)
        for output_descriptor in self.subworkflow.Outputs:
            mapped_items[output_descriptor.name] = [None] * item_count

        # No worker would ever post an event for an empty list, so waiting on the
        # queue below would block forever.
        if item_count == 0:
            return self.Outputs(**mapped_items)

        max_in_flight = item_count if self.concurrency is None else min(self.concurrency, item_count)

        self._event_queue: Queue[Tuple[int, WorkflowEvent]] = Queue()
        next_index = 0

        def start_next_iteration() -> None:
            # Launch the subworkflow for the next not-yet-started item on its own thread.
            nonlocal next_index
            thread = Thread(
                target=self._run_subworkflow,
                kwargs={"item": items[next_index], "index": next_index},
            )
            thread.start()
            next_index += 1

        for _ in range(max_in_flight):
            start_next_iteration()

        fulfilled_iterations: List[bool] = [False] * item_count
        fulfilled_count = 0

        # We should consolidate this logic with the logic workflow runner uses
        # https://app.shortcut.com/vellum/story/4736
        while fulfilled_count < item_count:
            # Blocks until some worker posts an event; every event is tagged with its iteration index.
            index, terminal_event = self._event_queue.get()

            if terminal_event.name == "workflow.execution.fulfilled":
                workflow_output_vars = vars(terminal_event.outputs)
                for output_name in workflow_output_vars:
                    output_mapped_items = mapped_items[output_name]
                    output_mapped_items[index] = workflow_output_vars[output_name]

                # Count each iteration once, even if a duplicate fulfilled event arrives.
                if not fulfilled_iterations[index]:
                    fulfilled_iterations[index] = True
                    fulfilled_count += 1
                    # A concurrency slot just freed up; hand it to the next waiting item.
                    if next_index < item_count:
                        start_next_iteration()
            elif terminal_event.name == "workflow.execution.paused":
                raise NodeException(
                    code=VellumErrorCode.INVALID_OUTPUTS,
                    message=f"Subworkflow unexpectedly paused on iteration {index}",
                )
            elif terminal_event.name == "workflow.execution.rejected":
                raise NodeException(
                    f"Subworkflow failed on iteration {index} with error: {terminal_event.error.message}",
                    code=terminal_event.error.code,
                )

        return self.Outputs(**mapped_items)

    def _run_subworkflow(self, *, item: MapNodeItemType, index: int) -> None:
        """Stream one iteration's subworkflow and forward each event, tagged with ``index``, to the shared queue."""
        subworkflow = self.subworkflow(parent_state=self.state)
        events = subworkflow.stream(inputs=self.SubworkflowInputs(index=index, item=item, all_items=self.items))

        for event in events:
            self._event_queue.put((index, event))

    @overload
    @classmethod
    def wrap(cls, items: List[MapNodeItemType]) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]: ...

    # TODO: We should be able to do this overload automatically as we do with node attributes
    # https://app.shortcut.com/vellum/story/5289
    @overload
    @classmethod
    def wrap(
        cls, items: BaseDescriptor[List[MapNodeItemType]]
    ) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]: ...

    @classmethod
    def wrap(
        cls, items: Union[List[MapNodeItemType], BaseDescriptor[List[MapNodeItemType]]]
    ) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]:
        """
        Build a decorator that turns a node class into a MapNode over ``items``.

        The decorated node becomes the graph of a generated one-node subworkflow. The returned
        node exposes one ``List`` output per output of the decorated node.
        """
        _items = items

        def decorator(inner_cls: Type[NodeType]) -> Type["MapNode[StateType, MapNodeItemType]"]:
            # Investigate how to use dependency injection to avoid circular imports
            # https://app.shortcut.com/vellum/story/4116
            from vellum.workflows import BaseWorkflow

            class Subworkflow(BaseWorkflow[MapNode.SubworkflowInputs, BaseState]):
                graph = inner_cls

                # mypy is wrong here, this works and is defined
                class Outputs(inner_cls.Outputs):  # type: ignore[name-defined]
                    pass

            class WrappedNodeOutputs(BaseOutputs):
                pass

            WrappedNodeOutputs.__annotations__ = {
                # TODO: We'll need to infer the type T of Subworkflow.Outputs[name] so we could do List[T] here
                # https://app.shortcut.com/vellum/story/4119
                descriptor.name: List
                for descriptor in inner_cls.Outputs
            }

            class WrappedNode(MapNode[StateType, MapNodeItemType]):
                items = _items
                subworkflow = Subworkflow

                class Outputs(WrappedNodeOutputs):
                    pass

            return WrappedNode

        return decorator
File without changes
@@ -0,0 +1,65 @@
1
+ import time
2
+
3
+ from vellum.workflows.inputs.base import BaseInputs
4
+ from vellum.workflows.nodes.bases import BaseNode
5
+ from vellum.workflows.nodes.core.map_node.node import MapNode
6
+ from vellum.workflows.outputs.base import BaseOutputs
7
+ from vellum.workflows.state.base import BaseState, StateMeta
8
+
9
+
10
def test_map_node__use_parent_inputs_and_state():
    # GIVEN inputs and state that belong to a parent workflow
    class Inputs(BaseInputs):
        foo: str

    class State(BaseState):
        bar: str

    # AND a map node whose inner node reads the current item plus the parent's input and state
    @MapNode.wrap(items=[1, 2, 3])
    class TestNode(BaseNode):
        item = MapNode.SubworkflowInputs.item
        foo = Inputs.foo
        bar = State.bar

        class Outputs(BaseOutputs):
            value: str

        def run(self) -> Outputs:
            return self.Outputs(value=f"{self.foo} {self.bar} {self.item}")

    # WHEN the node is run against that parent state
    parent_state = State(
        bar="bar",
        meta=StateMeta(workflow_inputs=Inputs(foo="foo")),
    )
    outputs = TestNode(state=parent_state).run()

    # THEN every iteration saw the parent's data
    expected_values = [f"foo bar {number}" for number in (1, 2, 3)]
    assert outputs.value == expected_values
42
+
43
+
44
def test_map_node__use_parallelism():
    # GIVEN a map node whose inner node sleeps briefly on every item
    sleep_seconds = 0.03
    item_count = 10

    @MapNode.wrap(items=list(range(item_count)))
    class TestNode(BaseNode):
        item = MapNode.SubworkflowInputs.item

        class Outputs(BaseOutputs):
            value: int

        def run(self) -> Outputs:
            time.sleep(sleep_seconds)
            return self.Outputs(value=self.item + 1)

    # WHEN the node is run
    node = TestNode(state=BaseState())
    # perf_counter is monotonic, so wall-clock adjustments cannot skew the measurement
    start = time.perf_counter()
    outputs = node.run()
    run_time = time.perf_counter() - start

    # THEN every item was processed, in order
    assert outputs.value == [index + 1 for index in range(item_count)]

    # AND the iterations overlapped: running them one after another would take at least
    # item_count * sleep_seconds, so finishing in half of that proves they ran in parallel
    assert run_time < item_count * sleep_seconds / 2
@@ -0,0 +1,5 @@
1
+ from .node import RetryNode
2
+
3
# Public API of this package: only the RetryNode class is re-exported.
__all__ = ["RetryNode"]
@@ -0,0 +1,106 @@
1
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Type
2
+
3
+ from vellum.workflows.errors.types import VellumErrorCode
4
+ from vellum.workflows.exceptions import NodeException
5
+ from vellum.workflows.inputs.base import BaseInputs
6
+ from vellum.workflows.nodes.bases import BaseNode
7
+ from vellum.workflows.nodes.bases.base import BaseNodeMeta
8
+ from vellum.workflows.state.base import BaseState
9
+ from vellum.workflows.types.generics import StateType
10
+
11
+ if TYPE_CHECKING:
12
+ from vellum.workflows import BaseWorkflow
13
+
14
+
15
class _RetryNodeMeta(BaseNodeMeta):
    @property
    def _localns(cls) -> Dict[str, Any]:
        # Extends the base metaclass's local namespace with this class's SubworkflowInputs.
        # NOTE(review): presumably this lets the string annotation on `subworkflow`
        # ("BaseWorkflow[SubworkflowInputs, BaseState]") resolve; confirm against BaseNodeMeta.
        return {
            **super()._localns,
            "SubworkflowInputs": getattr(cls, "SubworkflowInputs"),
        }


class RetryNode(BaseNode[StateType], Generic[StateType], metaclass=_RetryNodeMeta):
    """
    Used to retry a Subworkflow a specified number of times.

    max_attempts: int - The maximum number of attempts to retry the Subworkflow
    retry_on_error_code: Optional[VellumErrorCode] = None - The error code to retry on
    subworkflow: Type["BaseWorkflow[SubworkflowInputs, BaseState]"] - The Subworkflow to execute
    """

    max_attempts: int
    retry_on_error_code: Optional[VellumErrorCode] = None
    subworkflow: Type["BaseWorkflow[SubworkflowInputs, BaseState]"]

    class SubworkflowInputs(BaseInputs):
        attempt_number: int

    def run(self) -> BaseNode.Outputs:
        """
        Run the subworkflow up to ``max_attempts`` times, returning the outputs of the first fulfilled run.

        Each attempt receives its 1-based ``attempt_number`` as input. A rejection is retried when
        ``retry_on_error_code`` is unset or matches the rejection's error code.

        Raises:
            NodeException: when a run pauses, when a rejection's error code is not retryable,
                or when every attempt is rejected. In the last case the exception carries the
                final rejection's message and error code.
            Exception: when ``max_attempts`` is less than 1.
        """
        last_exception: Exception = Exception("max_attempts must be greater than 0")
        for index in range(self.max_attempts):
            attempt_number = index + 1
            subworkflow = self.subworkflow(
                parent_state=self.state,
            )
            terminal_event = subworkflow.run(
                inputs=self.SubworkflowInputs(attempt_number=attempt_number),
            )
            if terminal_event.name == "workflow.execution.fulfilled":
                node_outputs = self.Outputs()
                workflow_output_vars = vars(terminal_event.outputs)

                for output_name in workflow_output_vars:
                    setattr(node_outputs, output_name, workflow_output_vars[output_name])

                return node_outputs
            elif terminal_event.name == "workflow.execution.paused":
                last_exception = NodeException(
                    code=VellumErrorCode.INVALID_OUTPUTS,
                    message=f"Subworkflow unexpectedly paused on attempt {attempt_number}",
                )
                break
            elif self.retry_on_error_code is not None and self.retry_on_error_code != terminal_event.error.code:
                last_exception = NodeException(
                    code=VellumErrorCode.INVALID_OUTPUTS,
                    message=(
                        f"Unexpected rejection on attempt {attempt_number}: {terminal_event.error.code.value}.\n"
                        f"Message: {terminal_event.error.message}"
                    ),
                )
                break
            else:
                # Retryable rejection: remember it and keep the original error code, so the failure
                # surfaces accurately rather than as a generic error if every attempt is exhausted.
                last_exception = NodeException(
                    message=terminal_event.error.message,
                    code=terminal_event.error.code,
                )

        raise last_exception

    @classmethod
    def wrap(
        cls, max_attempts: int, retry_on_error_code: Optional[VellumErrorCode] = None
    ) -> Callable[..., Type["RetryNode"]]:
        """
        Build a decorator that wraps a node class in a RetryNode.

        The decorated node becomes the graph of a generated one-node subworkflow, which is
        re-run according to ``max_attempts`` and ``retry_on_error_code``.
        """
        _max_attempts = max_attempts
        _retry_on_error_code = retry_on_error_code

        def decorator(inner_cls: Type[BaseNode]) -> Type["RetryNode"]:
            # Investigate how to use dependency injection to avoid circular imports
            # https://app.shortcut.com/vellum/story/4116
            from vellum.workflows import BaseWorkflow

            class Subworkflow(BaseWorkflow[RetryNode.SubworkflowInputs, BaseState]):
                graph = inner_cls

                # mypy is wrong here, this works and is defined
                class Outputs(inner_cls.Outputs):  # type: ignore[name-defined]
                    pass

            class WrappedNode(RetryNode[StateType]):
                max_attempts = _max_attempts
                retry_on_error_code = _retry_on_error_code

                subworkflow = Subworkflow

                class Outputs(Subworkflow.Outputs):
                    pass

            return WrappedNode

        return decorator
@@ -0,0 +1,93 @@
1
+ import pytest
2
+
3
+ from vellum.workflows.errors.types import VellumErrorCode
4
+ from vellum.workflows.exceptions import NodeException
5
+ from vellum.workflows.inputs.base import BaseInputs
6
+ from vellum.workflows.nodes.bases import BaseNode
7
+ from vellum.workflows.nodes.core.retry_node.node import RetryNode
8
+ from vellum.workflows.outputs import BaseOutputs
9
+ from vellum.workflows.state.base import BaseState, StateMeta
10
+
11
+
12
def test_retry_node__retry_on_error_code__successfully_retried():
    # GIVEN a retry node that retries PROVIDER_ERROR rejections, up to three attempts
    @RetryNode.wrap(max_attempts=3, retry_on_error_code=VellumErrorCode.PROVIDER_ERROR)
    class TestNode(BaseNode):
        attempt_number = RetryNode.SubworkflowInputs.attempt_number

        class Outputs(BaseOutputs):
            execution_count: int

        def run(self) -> Outputs:
            # Succeed only on the third attempt; earlier attempts fail with a retryable error.
            if self.attempt_number >= 3:
                return self.Outputs(execution_count=self.attempt_number)
            raise NodeException(message="This will be retried", code=VellumErrorCode.PROVIDER_ERROR)

    # WHEN the node is run and its first attempts raise PROVIDER_ERROR
    outputs = TestNode(state=BaseState()).run()

    # THEN the failures are retried until the third attempt succeeds
    assert outputs.execution_count == 3
33
+
34
+
35
def test_retry_node__retry_on_error_code__missed():
    # GIVEN a retry node that only retries PROVIDER_ERROR rejections
    @RetryNode.wrap(max_attempts=3, retry_on_error_code=VellumErrorCode.PROVIDER_ERROR)
    class TestNode(BaseNode):
        attempt_number = RetryNode.SubworkflowInputs.attempt_number

        class Outputs(BaseOutputs):
            execution_count: int

        def run(self) -> Outputs:
            # Early attempts fail with a plain exception rather than a PROVIDER_ERROR.
            if self.attempt_number >= 3:
                return self.Outputs(execution_count=self.attempt_number)
            raise Exception("This will not be retried")

    # WHEN the node is run and raises an error with a different code
    node = TestNode(state=BaseState())
    with pytest.raises(NodeException) as exc_info:
        node.run()

    # THEN the first failure is surfaced instead of being retried
    raised = exc_info.value
    expected_message = "Unexpected rejection on attempt 1: INTERNAL_ERROR.\nMessage: This will not be retried"
    assert raised.message == expected_message
    assert raised.code == VellumErrorCode.INVALID_OUTPUTS
61
+
62
+
63
def test_retry_node__use_parent_inputs_and_state():
    # GIVEN inputs and state that belong to a parent workflow
    class Inputs(BaseInputs):
        foo: str

    class State(BaseState):
        bar: str

    # AND a retry node whose inner node reads the parent's input and state
    @RetryNode.wrap(max_attempts=3, retry_on_error_code=VellumErrorCode.PROVIDER_ERROR)
    class TestNode(BaseNode):
        foo = Inputs.foo
        bar = State.bar

        class Outputs(BaseOutputs):
            value: str

        def run(self) -> Outputs:
            return self.Outputs(value=f"{self.foo} {self.bar}")

    # WHEN the node is run against that parent state
    parent_state = State(
        bar="bar",
        meta=StateMeta(workflow_inputs=Inputs(foo="foo")),
    )
    outputs = TestNode(state=parent_state).run()

    # THEN the parent's data reaches the inner node
    assert outputs.value == "foo bar"
@@ -0,0 +1,5 @@
1
+ from .node import TemplatingNode
2
+
3
# Public API of this package: only the TemplatingNode class is re-exported.
__all__ = ["TemplatingNode"]
@@ -0,0 +1,12 @@
1
+ import json
2
+ from typing import Union
3
+
4
+
5
def is_valid_json_string(value: Union[str, bytes]) -> bool:
    """
    Determine whether ``value`` is text (``str`` or ``bytes``) containing valid JSON.

    This is registered as a custom Jinja filter, so it must answer rather than raise.
    Anything ``json.loads`` rejects yields ``False`` instead of an exception, including:
    malformed JSON, bytes that cannot be decoded, and values that are not text at all
    (for example ``None`` or a number).
    """
    try:
        json.loads(value)
    except (TypeError, ValueError):
        # ValueError covers JSONDecodeError and UnicodeDecodeError;
        # TypeError covers inputs that are not str/bytes/bytearray.
        return False
    return True
@@ -0,0 +1,2 @@
1
class JinjaTemplateError(Exception):
    """Error type for Jinja template failures in the templating node (presumably raised by its render helpers)."""
@@ -0,0 +1,123 @@
1
+ import datetime
2
+ import itertools
3
+ import json
4
+ import random
5
+ import re
6
+ from typing import Any, Callable, ClassVar, Dict, Generic, Mapping, Tuple, Type, TypeVar, Union, get_args
7
+
8
+ import dateutil.parser
9
+ import pydash
10
+ import pytz
11
+ import yaml
12
+
13
+ from vellum.workflows.errors import VellumErrorCode
14
+ from vellum.workflows.exceptions import NodeException
15
+ from vellum.workflows.nodes.bases import BaseNode
16
+ from vellum.workflows.nodes.bases.base import BaseNodeMeta
17
+ from vellum.workflows.nodes.core.templating_node.custom_filters import is_valid_json_string
18
+ from vellum.workflows.nodes.core.templating_node.exceptions import JinjaTemplateError
19
+ from vellum.workflows.nodes.core.templating_node.render import render_sandboxed_jinja_template
20
+ from vellum.workflows.types.core import EntityInputsInterface
21
+ from vellum.workflows.types.generics import StateType
22
+ from vellum.workflows.types.utils import get_original_base
23
+
24
# Globals made available to every template rendered by a TemplatingNode.
# NOTE(review): these modules are exposed inside the Jinja SandboxedEnvironment. The sandbox
# guards attribute access, but it does not limit what public module functions do. In particular,
# `yaml.unsafe_load` / `yaml.load(..., Loader=yaml.UnsafeLoader)` (if present in the installed
# PyYAML) can build arbitrary Python objects from template-controlled text. Confirm templates are
# trusted, or consider exposing only yaml's safe_* helpers.
_DEFAULT_JINJA_GLOBALS: Dict[str, Any] = dict(
    datetime=datetime,
    dateutil=dateutil,
    itertools=itertools,
    json=json,
    pydash=pydash,
    pytz=pytz,
    random=random,
    re=re,
    yaml=yaml,
)

# Custom filters registered on the Jinja environment. Each maps a str/bytes value to a bool.
_DEFAULT_JINJA_CUSTOM_FILTERS: Dict[str, Callable[[Union[str, bytes]], bool]] = dict(
    is_valid_json_string=is_valid_json_string,
)

# Generic parameter for a TemplatingNode's output type. When left unbound, the node resolves it to `str`.
_OutputType = TypeVar("_OutputType")
41
+
42
+
43
+ # TODO: Consolidate all dynamic output metaclasses
44
+ # https://app.shortcut.com/vellum/story/5533
45
class _TemplatingNodeMeta(BaseNodeMeta):
    """Metaclass that stamps the concrete output type onto each TemplatingNode's ``Outputs.result``."""

    def __new__(mcs, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any]) -> Any:
        node_class = super().__new__(mcs, name, bases, dct)

        # Guard that also narrows the type for static checkers.
        if not isinstance(node_class, _TemplatingNodeMeta):
            raise ValueError("TemplatingNode must be created with the TemplatingNodeMeta metaclass")

        # Resolve the type first, then patch the annotation on this class's own Outputs.
        resolved_type = node_class.get_output_type()
        node_class.__dict__["Outputs"].__annotations__["result"] = resolved_type
        return node_class

    def get_output_type(cls) -> Type:
        """Return the second generic argument of the class's original base.

        Falls back to ``str`` when that argument is absent or is still an unbound TypeVar.
        """
        type_args = get_args(get_original_base(cls))
        if len(type_args) >= 2 and not isinstance(type_args[1], TypeVar):
            return type_args[1]
        return str
63
+
64
+
65
class TemplatingNode(BaseNode[StateType], Generic[StateType, _OutputType], metaclass=_TemplatingNodeMeta):
    """Used to render a Jinja template.

    Useful for lightweight data transformations and complex string templating.

    The second generic parameter selects the output type. The rendered text is
    returned as-is for ``str`` (the default when unspecified). For ``float``,
    ``int`` or ``bool`` it is converted to that type.
    """

    # The Jinja template to render.
    template: ClassVar[str]

    # The inputs to render the template with.
    inputs: ClassVar[EntityInputsInterface]

    # Globals and filters exposed to the template. They default to the module-level
    # _DEFAULT_JINJA_GLOBALS / _DEFAULT_JINJA_CUSTOM_FILTERS mappings.
    jinja_globals: Dict[str, Any] = _DEFAULT_JINJA_GLOBALS
    jinja_custom_filters: Mapping[str, Callable[[Union[str, bytes]], bool]] = _DEFAULT_JINJA_CUSTOM_FILTERS

    class Outputs(BaseNode.Outputs):
        # We use our mypy plugin to override the _OutputType with the actual output type
        # for downstream references to this output.
        result: _OutputType  # type: ignore[valid-type]

    def _cast_rendered_template(self, rendered_template: str) -> Any:
        """Convert the rendered template text into the node's declared output type.

        Raises:
            ValueError: if the text cannot be parsed as the requested numeric type,
                or if the output type is not one of str/float/int/bool.
        """
        # Reuse the metaclass's resolution so the Outputs.result annotation and the
        # runtime cast always agree on the output type.
        output_type = self.__class__.get_output_type()

        if output_type is str:
            return rendered_template

        if output_type is float:
            return float(rendered_template)

        if output_type is int:
            return int(rendered_template)

        if output_type is bool:
            # bool() on any non-empty string is True, so a template that renders "False"
            # (or "False\n", since trailing newlines are kept) would wrongly become True.
            # Treat the common falsy literals as False. Any other non-blank text stays True,
            # as it was before.
            return rendered_template.strip().lower() not in {"", "false", "0", "none", "null"}

        raise ValueError(f"Unsupported output type: {output_type}")

    def run(self) -> Outputs:
        rendered_template = self._render_template()
        result = self._cast_rendered_template(rendered_template)

        return self.Outputs(result=result)

    def _render_template(self) -> str:
        """Render ``template`` with ``inputs`` and report template failures as NodeExceptions."""
        try:
            return render_sandboxed_jinja_template(
                template=self.template,
                input_values=self.inputs,
                # Copy the Mapping into a plain dict, which is what the renderer accepts.
                jinja_custom_filters={**self.jinja_custom_filters},
                jinja_globals=self.jinja_globals,
            )
        except JinjaTemplateError as e:
            raise NodeException(message=str(e), code=VellumErrorCode.INVALID_TEMPLATE) from e
@@ -0,0 +1,55 @@
1
+ import json
2
+ from typing import Any, Callable, Dict, Optional, Union
3
+
4
+ from jinja2.sandbox import SandboxedEnvironment
5
+
6
+ from vellum.workflows.nodes.core.templating_node.exceptions import JinjaTemplateError
7
+ from vellum.workflows.state.encoder import DefaultStateEncoder
8
+
9
+
10
def finalize(obj: Any) -> str:
    """Turn the value of a ``{{ ... }}`` expression into rendered text.

    This is the Jinja environment's ``finalize`` hook. Dicts and lists are
    serialized as JSON (using ``DefaultStateEncoder``), so structured values render
    as parseable JSON rather than Python reprs with single quotes and
    ``True``/``None``. Every other value falls back to ``str()``.
    """
    if isinstance(obj, (dict, list)):
        return json.dumps(obj, cls=DefaultStateEncoder)

    return str(obj)
15
+
16
+
17
def render_sandboxed_jinja_template(
    *,
    template: str,
    input_values: Dict[str, Any],
    jinja_custom_filters: Optional[Dict[str, Callable[[Union[str, bytes]], bool]]] = None,
    jinja_globals: Optional[Dict[str, Any]] = None,
) -> str:
    """Render a Jinja template within a sandboxed environment.

    Args:
        template: Jinja source to render.
        input_values: Variables available to the template.
        jinja_custom_filters: Extra filters registered on the environment, if any.
        jinja_globals: Extra globals made available to the template, if any.

    Returns:
        The rendered text. A trailing newline in ``template`` is preserved.

    Raises:
        JinjaTemplateError: if compiling or rendering fails for any reason. This
            includes ``json.loads`` calls inside the template that receive invalid JSON.
    """
    try:
        environment = SandboxedEnvironment(
            keep_trailing_newline=True,
            finalize=finalize,
        )

        if jinja_custom_filters:
            environment.filters.update(jinja_custom_filters)

        jinja_template = environment.from_string(template)

        if jinja_globals:
            jinja_template.globals.update(jinja_globals)

        rendered_template = jinja_template.render(input_values)
    except json.JSONDecodeError as e:
        # The C scanner reports "Invalid control character at". The pure-Python fallback
        # puts the offending character inside the message, so match on the prefix only.
        if e.msg.startswith("Invalid control character"):
            raise JinjaTemplateError(
                "Unable to render jinja template:\n"
                "Cannot run json.loads() on JSON containing control characters. "
                "Use json.loads(input, strict=False) instead.",
            ) from e

        raise JinjaTemplateError(
            f"Unable to render jinja template:\nCannot run json.loads() on invalid JSON\n{e.args[0]}"
        ) from e
    except Exception as e:
        # Some exceptions carry no args (e.g. a bare `raise SomeError()`). Indexing
        # args[0] would then raise IndexError here and hide the real failure.
        detail = e.args[0] if e.args else type(e).__name__
        raise JinjaTemplateError(f"Unable to render jinja template:\n{detail}") from e

    return rendered_template
@@ -0,0 +1,21 @@
1
+ import json
2
+
3
+ from vellum.workflows.nodes.core.templating_node.node import TemplatingNode
4
+
5
+
6
def test_templating_node__dict_output():
    """Rendering a dict input through a TemplatingNode yields JSON text."""

    # GIVEN a templating node whose template simply echoes a dict input
    class TemplateNode(TemplatingNode):
        template = "{{ data }}"
        inputs = {"data": {"key": "value"}}

    # WHEN the node is run
    outputs = TemplateNode().run()

    # THEN the rendered result parses back into the original dict
    assert json.loads(outputs.result) == {"key": "value"}
@@ -0,0 +1,5 @@
1
+ from .node import TryNode
2
+
3
# Public API of the try_node package.
__all__ = ["TryNode"]