vellum-ai 0.9.16rc2__py3-none-any.whl → 0.9.16rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. vellum/plugins/__init__.py +0 -0
  2. vellum/plugins/pydantic.py +74 -0
  3. vellum/plugins/utils.py +19 -0
  4. vellum/plugins/vellum_mypy.py +639 -3
  5. vellum/workflows/README.md +90 -0
  6. vellum/workflows/__init__.py +5 -0
  7. vellum/workflows/constants.py +43 -0
  8. vellum/workflows/descriptors/__init__.py +0 -0
  9. vellum/workflows/descriptors/base.py +339 -0
  10. vellum/workflows/descriptors/tests/test_utils.py +83 -0
  11. vellum/workflows/descriptors/utils.py +90 -0
  12. vellum/workflows/edges/__init__.py +5 -0
  13. vellum/workflows/edges/edge.py +23 -0
  14. vellum/workflows/emitters/__init__.py +5 -0
  15. vellum/workflows/emitters/base.py +14 -0
  16. vellum/workflows/environment/__init__.py +5 -0
  17. vellum/workflows/environment/environment.py +7 -0
  18. vellum/workflows/errors/__init__.py +6 -0
  19. vellum/workflows/errors/types.py +20 -0
  20. vellum/workflows/events/__init__.py +31 -0
  21. vellum/workflows/events/node.py +125 -0
  22. vellum/workflows/events/tests/__init__.py +0 -0
  23. vellum/workflows/events/tests/test_event.py +216 -0
  24. vellum/workflows/events/types.py +52 -0
  25. vellum/workflows/events/utils.py +5 -0
  26. vellum/workflows/events/workflow.py +139 -0
  27. vellum/workflows/exceptions.py +15 -0
  28. vellum/workflows/expressions/__init__.py +0 -0
  29. vellum/workflows/expressions/accessor.py +52 -0
  30. vellum/workflows/expressions/and_.py +32 -0
  31. vellum/workflows/expressions/begins_with.py +31 -0
  32. vellum/workflows/expressions/between.py +38 -0
  33. vellum/workflows/expressions/coalesce_expression.py +41 -0
  34. vellum/workflows/expressions/contains.py +30 -0
  35. vellum/workflows/expressions/does_not_begin_with.py +31 -0
  36. vellum/workflows/expressions/does_not_contain.py +30 -0
  37. vellum/workflows/expressions/does_not_end_with.py +31 -0
  38. vellum/workflows/expressions/does_not_equal.py +25 -0
  39. vellum/workflows/expressions/ends_with.py +31 -0
  40. vellum/workflows/expressions/equals.py +25 -0
  41. vellum/workflows/expressions/greater_than.py +33 -0
  42. vellum/workflows/expressions/greater_than_or_equal_to.py +33 -0
  43. vellum/workflows/expressions/in_.py +31 -0
  44. vellum/workflows/expressions/is_blank.py +24 -0
  45. vellum/workflows/expressions/is_not_blank.py +24 -0
  46. vellum/workflows/expressions/is_not_null.py +21 -0
  47. vellum/workflows/expressions/is_not_undefined.py +22 -0
  48. vellum/workflows/expressions/is_null.py +21 -0
  49. vellum/workflows/expressions/is_undefined.py +22 -0
  50. vellum/workflows/expressions/less_than.py +33 -0
  51. vellum/workflows/expressions/less_than_or_equal_to.py +33 -0
  52. vellum/workflows/expressions/not_between.py +38 -0
  53. vellum/workflows/expressions/not_in.py +31 -0
  54. vellum/workflows/expressions/or_.py +32 -0
  55. vellum/workflows/graph/__init__.py +3 -0
  56. vellum/workflows/graph/graph.py +131 -0
  57. vellum/workflows/graph/tests/__init__.py +0 -0
  58. vellum/workflows/graph/tests/test_graph.py +437 -0
  59. vellum/workflows/inputs/__init__.py +5 -0
  60. vellum/workflows/inputs/base.py +55 -0
  61. vellum/workflows/logging.py +14 -0
  62. vellum/workflows/nodes/__init__.py +46 -0
  63. vellum/workflows/nodes/bases/__init__.py +7 -0
  64. vellum/workflows/nodes/bases/base.py +332 -0
  65. vellum/workflows/nodes/bases/base_subworkflow_node/__init__.py +5 -0
  66. vellum/workflows/nodes/bases/base_subworkflow_node/node.py +10 -0
  67. vellum/workflows/nodes/bases/tests/__init__.py +0 -0
  68. vellum/workflows/nodes/bases/tests/test_base_node.py +125 -0
  69. vellum/workflows/nodes/core/__init__.py +16 -0
  70. vellum/workflows/nodes/core/error_node/__init__.py +5 -0
  71. vellum/workflows/nodes/core/error_node/node.py +26 -0
  72. vellum/workflows/nodes/core/inline_subworkflow_node/__init__.py +5 -0
  73. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +73 -0
  74. vellum/workflows/nodes/core/map_node/__init__.py +5 -0
  75. vellum/workflows/nodes/core/map_node/node.py +147 -0
  76. vellum/workflows/nodes/core/map_node/tests/__init__.py +0 -0
  77. vellum/workflows/nodes/core/map_node/tests/test_node.py +65 -0
  78. vellum/workflows/nodes/core/retry_node/__init__.py +5 -0
  79. vellum/workflows/nodes/core/retry_node/node.py +106 -0
  80. vellum/workflows/nodes/core/retry_node/tests/__init__.py +0 -0
  81. vellum/workflows/nodes/core/retry_node/tests/test_node.py +93 -0
  82. vellum/workflows/nodes/core/templating_node/__init__.py +5 -0
  83. vellum/workflows/nodes/core/templating_node/custom_filters.py +12 -0
  84. vellum/workflows/nodes/core/templating_node/exceptions.py +2 -0
  85. vellum/workflows/nodes/core/templating_node/node.py +123 -0
  86. vellum/workflows/nodes/core/templating_node/render.py +55 -0
  87. vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +21 -0
  88. vellum/workflows/nodes/core/try_node/__init__.py +5 -0
  89. vellum/workflows/nodes/core/try_node/node.py +110 -0
  90. vellum/workflows/nodes/core/try_node/tests/__init__.py +0 -0
  91. vellum/workflows/nodes/core/try_node/tests/test_node.py +82 -0
  92. vellum/workflows/nodes/displayable/__init__.py +31 -0
  93. vellum/workflows/nodes/displayable/api_node/__init__.py +5 -0
  94. vellum/workflows/nodes/displayable/api_node/node.py +44 -0
  95. vellum/workflows/nodes/displayable/bases/__init__.py +11 -0
  96. vellum/workflows/nodes/displayable/bases/api_node/__init__.py +5 -0
  97. vellum/workflows/nodes/displayable/bases/api_node/node.py +70 -0
  98. vellum/workflows/nodes/displayable/bases/base_prompt_node/__init__.py +5 -0
  99. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +60 -0
  100. vellum/workflows/nodes/displayable/bases/inline_prompt_node/__init__.py +5 -0
  101. vellum/workflows/nodes/displayable/bases/inline_prompt_node/constants.py +13 -0
  102. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +118 -0
  103. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +98 -0
  104. vellum/workflows/nodes/displayable/bases/search_node.py +90 -0
  105. vellum/workflows/nodes/displayable/code_execution_node/__init__.py +5 -0
  106. vellum/workflows/nodes/displayable/code_execution_node/node.py +197 -0
  107. vellum/workflows/nodes/displayable/code_execution_node/tests/__init__.py +0 -0
  108. vellum/workflows/nodes/displayable/code_execution_node/tests/fixtures/__init__.py +0 -0
  109. vellum/workflows/nodes/displayable/code_execution_node/tests/fixtures/main.py +3 -0
  110. vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +111 -0
  111. vellum/workflows/nodes/displayable/code_execution_node/utils.py +10 -0
  112. vellum/workflows/nodes/displayable/conditional_node/__init__.py +5 -0
  113. vellum/workflows/nodes/displayable/conditional_node/node.py +25 -0
  114. vellum/workflows/nodes/displayable/final_output_node/__init__.py +5 -0
  115. vellum/workflows/nodes/displayable/final_output_node/node.py +43 -0
  116. vellum/workflows/nodes/displayable/guardrail_node/__init__.py +5 -0
  117. vellum/workflows/nodes/displayable/guardrail_node/node.py +97 -0
  118. vellum/workflows/nodes/displayable/inline_prompt_node/__init__.py +5 -0
  119. vellum/workflows/nodes/displayable/inline_prompt_node/node.py +41 -0
  120. vellum/workflows/nodes/displayable/merge_node/__init__.py +5 -0
  121. vellum/workflows/nodes/displayable/merge_node/node.py +10 -0
  122. vellum/workflows/nodes/displayable/prompt_deployment_node/__init__.py +5 -0
  123. vellum/workflows/nodes/displayable/prompt_deployment_node/node.py +45 -0
  124. vellum/workflows/nodes/displayable/search_node/__init__.py +5 -0
  125. vellum/workflows/nodes/displayable/search_node/node.py +26 -0
  126. vellum/workflows/nodes/displayable/subworkflow_deployment_node/__init__.py +5 -0
  127. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +156 -0
  128. vellum/workflows/nodes/displayable/tests/__init__.py +0 -0
  129. vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py +148 -0
  130. vellum/workflows/nodes/displayable/tests/test_search_node_wth_text_output.py +134 -0
  131. vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +80 -0
  132. vellum/workflows/nodes/utils.py +27 -0
  133. vellum/workflows/outputs/__init__.py +6 -0
  134. vellum/workflows/outputs/base.py +196 -0
  135. vellum/workflows/ports/__init__.py +7 -0
  136. vellum/workflows/ports/node_ports.py +75 -0
  137. vellum/workflows/ports/port.py +75 -0
  138. vellum/workflows/ports/utils.py +40 -0
  139. vellum/workflows/references/__init__.py +17 -0
  140. vellum/workflows/references/environment_variable.py +20 -0
  141. vellum/workflows/references/execution_count.py +20 -0
  142. vellum/workflows/references/external_input.py +49 -0
  143. vellum/workflows/references/input.py +7 -0
  144. vellum/workflows/references/lazy.py +55 -0
  145. vellum/workflows/references/node.py +43 -0
  146. vellum/workflows/references/output.py +78 -0
  147. vellum/workflows/references/state_value.py +23 -0
  148. vellum/workflows/references/vellum_secret.py +15 -0
  149. vellum/workflows/references/workflow_input.py +41 -0
  150. vellum/workflows/resolvers/__init__.py +5 -0
  151. vellum/workflows/resolvers/base.py +15 -0
  152. vellum/workflows/runner/__init__.py +5 -0
  153. vellum/workflows/runner/runner.py +588 -0
  154. vellum/workflows/runner/types.py +18 -0
  155. vellum/workflows/state/__init__.py +5 -0
  156. vellum/workflows/state/base.py +327 -0
  157. vellum/workflows/state/context.py +18 -0
  158. vellum/workflows/state/encoder.py +57 -0
  159. vellum/workflows/state/store.py +28 -0
  160. vellum/workflows/state/tests/__init__.py +0 -0
  161. vellum/workflows/state/tests/test_state.py +113 -0
  162. vellum/workflows/types/__init__.py +0 -0
  163. vellum/workflows/types/core.py +91 -0
  164. vellum/workflows/types/generics.py +14 -0
  165. vellum/workflows/types/stack.py +39 -0
  166. vellum/workflows/types/tests/__init__.py +0 -0
  167. vellum/workflows/types/tests/test_utils.py +76 -0
  168. vellum/workflows/types/utils.py +164 -0
  169. vellum/workflows/utils/__init__.py +0 -0
  170. vellum/workflows/utils/names.py +13 -0
  171. vellum/workflows/utils/tests/__init__.py +0 -0
  172. vellum/workflows/utils/tests/test_names.py +15 -0
  173. vellum/workflows/utils/tests/test_vellum_variables.py +25 -0
  174. vellum/workflows/utils/vellum_variables.py +81 -0
  175. vellum/workflows/vellum_client.py +18 -0
  176. vellum/workflows/workflows/__init__.py +5 -0
  177. vellum/workflows/workflows/base.py +365 -0
  178. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/METADATA +2 -1
  179. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/RECORD +245 -7
  180. vellum_cli/__init__.py +72 -0
  181. vellum_cli/aliased_group.py +103 -0
  182. vellum_cli/config.py +96 -0
  183. vellum_cli/image_push.py +112 -0
  184. vellum_cli/logger.py +36 -0
  185. vellum_cli/pull.py +73 -0
  186. vellum_cli/push.py +121 -0
  187. vellum_cli/tests/test_config.py +100 -0
  188. vellum_cli/tests/test_pull.py +152 -0
  189. vellum_ee/workflows/__init__.py +0 -0
  190. vellum_ee/workflows/display/__init__.py +0 -0
  191. vellum_ee/workflows/display/base.py +73 -0
  192. vellum_ee/workflows/display/nodes/__init__.py +4 -0
  193. vellum_ee/workflows/display/nodes/base_node_display.py +116 -0
  194. vellum_ee/workflows/display/nodes/base_node_vellum_display.py +36 -0
  195. vellum_ee/workflows/display/nodes/get_node_display_class.py +25 -0
  196. vellum_ee/workflows/display/nodes/tests/__init__.py +0 -0
  197. vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +47 -0
  198. vellum_ee/workflows/display/nodes/types.py +18 -0
  199. vellum_ee/workflows/display/nodes/utils.py +33 -0
  200. vellum_ee/workflows/display/nodes/vellum/__init__.py +32 -0
  201. vellum_ee/workflows/display/nodes/vellum/api_node.py +205 -0
  202. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +71 -0
  203. vellum_ee/workflows/display/nodes/vellum/conditional_node.py +217 -0
  204. vellum_ee/workflows/display/nodes/vellum/final_output_node.py +61 -0
  205. vellum_ee/workflows/display/nodes/vellum/guardrail_node.py +49 -0
  206. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +170 -0
  207. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +99 -0
  208. vellum_ee/workflows/display/nodes/vellum/map_node.py +100 -0
  209. vellum_ee/workflows/display/nodes/vellum/merge_node.py +48 -0
  210. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +68 -0
  211. vellum_ee/workflows/display/nodes/vellum/search_node.py +193 -0
  212. vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +58 -0
  213. vellum_ee/workflows/display/nodes/vellum/templating_node.py +67 -0
  214. vellum_ee/workflows/display/nodes/vellum/tests/__init__.py +0 -0
  215. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +106 -0
  216. vellum_ee/workflows/display/nodes/vellum/try_node.py +38 -0
  217. vellum_ee/workflows/display/nodes/vellum/utils.py +76 -0
  218. vellum_ee/workflows/display/tests/__init__.py +0 -0
  219. vellum_ee/workflows/display/tests/workflow_serialization/__init__.py +0 -0
  220. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +426 -0
  221. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +607 -0
  222. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +1175 -0
  223. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +235 -0
  224. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +511 -0
  225. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +372 -0
  226. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +272 -0
  227. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +289 -0
  228. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +354 -0
  229. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +123 -0
  230. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +84 -0
  231. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +233 -0
  232. vellum_ee/workflows/display/types.py +46 -0
  233. vellum_ee/workflows/display/utils/__init__.py +0 -0
  234. vellum_ee/workflows/display/utils/tests/__init__.py +0 -0
  235. vellum_ee/workflows/display/utils/tests/test_uuids.py +16 -0
  236. vellum_ee/workflows/display/utils/uuids.py +24 -0
  237. vellum_ee/workflows/display/utils/vellum.py +121 -0
  238. vellum_ee/workflows/display/vellum.py +357 -0
  239. vellum_ee/workflows/display/workflows/__init__.py +5 -0
  240. vellum_ee/workflows/display/workflows/base_workflow_display.py +302 -0
  241. vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +32 -0
  242. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +386 -0
  243. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/LICENSE +0 -0
  244. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/WHEEL +0 -0
  245. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,147 @@
1
+ from collections import defaultdict
2
+ from queue import Empty, Queue
3
+ from threading import Thread
4
+ from typing import TYPE_CHECKING, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, overload
5
+
6
+ from vellum.workflows.descriptors.base import BaseDescriptor
7
+ from vellum.workflows.errors.types import VellumErrorCode
8
+ from vellum.workflows.exceptions import NodeException
9
+ from vellum.workflows.inputs.base import BaseInputs
10
+ from vellum.workflows.nodes.bases import BaseNode
11
+ from vellum.workflows.outputs import BaseOutputs
12
+ from vellum.workflows.state.base import BaseState
13
+ from vellum.workflows.types.generics import NodeType, StateType
14
+
15
+ if TYPE_CHECKING:
16
+ from vellum.workflows import BaseWorkflow
17
+ from vellum.workflows.events.workflow import WorkflowEvent
18
+
19
# Type of a single element of MapNode.items; also parameterizes the per-iteration SubworkflowInputs.
MapNodeItemType = TypeVar("MapNodeItemType")
20
+
21
+
22
class MapNode(BaseNode, Generic[StateType, MapNodeItemType]):
    """
    Used to map over a list of items and execute a Subworkflow on each iteration.

    items: List[MapNodeItemType] - The items to map over
    subworkflow: Type["BaseWorkflow[SubworkflowInputs, BaseState]"] - The Subworkflow to execute on each iteration
    concurrency: Optional[int] = None - The maximum number of concurrent subworkflow executions. None means every
        iteration may run at once; otherwise it must be a positive integer.

    Each node output is a list with one slot per item, where slot ``i`` holds the value the Subworkflow produced
    for iteration ``i``.
    """

    items: List[MapNodeItemType]
    subworkflow: Type["BaseWorkflow"]
    concurrency: Optional[int] = None

    class Outputs(BaseOutputs):
        mapped_items: list

    class SubworkflowInputs(BaseInputs):
        # TODO: Both type: ignore's below are believed to be incorrect and both have the following error:
        # Type variable "workflows.nodes.map_node.map_node.MapNodeItemType" is unbound
        # https://app.shortcut.com/vellum/story/4118

        item: MapNodeItemType  # type: ignore[valid-type]
        index: int
        all_items: List[MapNodeItemType]  # type: ignore[valid-type]

    def run(self) -> Outputs:
        """
        Run the Subworkflow once per item on background threads and gather each output into an ordered list.

        Returns:
            Outputs whose attributes are lists indexed by iteration (empty lists when ``items`` is empty).

        Raises:
            ValueError: if ``concurrency`` is set to a value below 1.
            NodeException: as soon as any iteration pauses or is rejected.
        """
        # Pre-size one slot per item so each result lands at its iteration's index regardless of completion order.
        mapped_items: Dict[str, List] = defaultdict(list)
        for output_descriptor in self.subworkflow.Outputs:
            mapped_items[output_descriptor.name] = [None] * len(self.items)

        # With no items no thread would ever publish an event, so waiting on the queue below would block forever.
        if not self.items:
            return self.Outputs(**mapped_items)

        # A limit below 1 would deadlock (Semaphore(0)) or raise deep inside threading; reject it up front.
        if self.concurrency is not None and self.concurrency < 1:
            raise ValueError(f"MapNode concurrency must be a positive integer, got {self.concurrency}")

        from threading import Semaphore

        self._event_queue: Queue[Tuple[int, WorkflowEvent]] = Queue()
        # Caps how many iterations execute their Subworkflow simultaneously; None means no cap.
        self._concurrency_limiter: Optional[Semaphore] = (
            Semaphore(self.concurrency) if self.concurrency is not None else None
        )
        fulfilled_iterations: List[bool] = [False] * len(self.items)
        for index, item in enumerate(self.items):
            thread = Thread(target=self._run_subworkflow, kwargs={"item": item, "index": index})
            thread.start()

        # We should consolidate this logic with the logic workflow runner uses
        # https://app.shortcut.com/vellum/story/4736
        # Queue.get() blocks with no timeout, so this loop only ends by the `break` below (every iteration
        # fulfilled) or by raising on the first paused/rejected iteration.
        while map_node_event := self._event_queue.get():
            index, terminal_event = map_node_event

            if terminal_event.name == "workflow.execution.fulfilled":
                workflow_output_vars = vars(terminal_event.outputs)

                for output_name in workflow_output_vars:
                    output_mapped_items = mapped_items[output_name]
                    output_mapped_items[index] = workflow_output_vars[output_name]

                fulfilled_iterations[index] = True
                if all(fulfilled_iterations):
                    break
            elif terminal_event.name == "workflow.execution.paused":
                raise NodeException(
                    code=VellumErrorCode.INVALID_OUTPUTS,
                    message=f"Subworkflow unexpectedly paused on iteration {index}",
                )
            elif terminal_event.name == "workflow.execution.rejected":
                raise NodeException(
                    f"Subworkflow failed on iteration {index} with error: {terminal_event.error.message}",
                    code=terminal_event.error.code,
                )

        return self.Outputs(**mapped_items)

    def _run_subworkflow(self, *, item: MapNodeItemType, index: int) -> None:
        """Execute one iteration's Subworkflow and forward each streamed event, tagged with its index, to the queue."""
        limiter = self._concurrency_limiter
        if limiter is not None:
            limiter.acquire()
        try:
            subworkflow = self.subworkflow(parent_state=self.state)
            events = subworkflow.stream(inputs=self.SubworkflowInputs(index=index, item=item, all_items=self.items))

            for event in events:
                self._event_queue.put((index, event))
        finally:
            # Free the slot even if the Subworkflow raised, so iterations waiting on the limiter can proceed.
            if limiter is not None:
                limiter.release()

    @overload
    @classmethod
    def wrap(
        cls, items: List[MapNodeItemType], concurrency: Optional[int] = None
    ) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]: ...

    # TODO: We should be able to do this overload automatically as we do with node attributes
    # https://app.shortcut.com/vellum/story/5289
    @overload
    @classmethod
    def wrap(
        cls, items: BaseDescriptor[List[MapNodeItemType]], concurrency: Optional[int] = None
    ) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]: ...

    @classmethod
    def wrap(
        cls,
        items: Union[List[MapNodeItemType], BaseDescriptor[List[MapNodeItemType]]],
        concurrency: Optional[int] = None,
    ) -> Callable[..., Type["MapNode[StateType, MapNodeItemType]"]]:
        """
        Build a decorator that turns a node class into a MapNode running that node once per item.

        items: the list (or descriptor resolving to a list) to map over
        concurrency: optional cap on simultaneous iterations; defaults to no cap
        """
        # Class bodies cannot see same-named enclosing-function variables, hence the underscore aliases.
        _items = items
        _concurrency = concurrency

        def decorator(inner_cls: Type[NodeType]) -> Type["MapNode[StateType, MapNodeItemType]"]:
            # Investigate how to use dependency injection to avoid circular imports
            # https://app.shortcut.com/vellum/story/4116
            from vellum.workflows import BaseWorkflow

            class Subworkflow(BaseWorkflow[MapNode.SubworkflowInputs, BaseState]):
                graph = inner_cls

                # mypy is wrong here, this works and is defined
                class Outputs(inner_cls.Outputs):  # type: ignore[name-defined]
                    pass

            class WrappedNodeOutputs(BaseOutputs):
                pass

            WrappedNodeOutputs.__annotations__ = {
                # TODO: We'll need to infer the type T of Subworkflow.Outputs[name] so we could do List[T] here
                # https://app.shortcut.com/vellum/story/4119
                descriptor.name: List
                for descriptor in inner_cls.Outputs
            }

            class WrappedNode(MapNode[StateType, MapNodeItemType]):
                items = _items
                subworkflow = Subworkflow
                concurrency = _concurrency

                class Outputs(WrappedNodeOutputs):
                    pass

            return WrappedNode

        return decorator
File without changes
@@ -0,0 +1,65 @@
1
+ import time
2
+
3
+ from vellum.workflows.inputs.base import BaseInputs
4
+ from vellum.workflows.nodes.bases import BaseNode
5
+ from vellum.workflows.nodes.core.map_node.node import MapNode
6
+ from vellum.workflows.outputs.base import BaseOutputs
7
+ from vellum.workflows.state.base import BaseState, StateMeta
8
+
9
+
10
def test_map_node__use_parent_inputs_and_state():
    # GIVEN the parent workflow's Inputs and State classes
    class Inputs(BaseInputs):
        foo: str

    class State(BaseState):
        bar: str

    # AND a map node whose wrapped node combines the current item with the parent's inputs and state
    @MapNode.wrap(items=[1, 2, 3])
    class TestNode(BaseNode):
        item = MapNode.SubworkflowInputs.item
        foo = Inputs.foo
        bar = State.bar

        class Outputs(BaseOutputs):
            value: str

        def run(self) -> Outputs:
            return self.Outputs(value=f"{self.foo} {self.bar} {self.item}")

    # WHEN it runs against a parent state that carries those inputs
    parent_state = State(bar="bar", meta=StateMeta(workflow_inputs=Inputs(foo="foo")))
    outputs = TestNode(state=parent_state).run()

    # THEN every iteration saw the parent's data, and the results follow item order
    assert outputs.value == ["foo bar 1", "foo bar 2", "foo bar 3"]
42
+
43
+
44
def test_map_node__use_parallelism():
    # GIVEN a map node over ten items whose wrapped node sleeps 30ms per item
    @MapNode.wrap(items=list(range(10)))
    class TestNode(BaseNode):
        item = MapNode.SubworkflowInputs.item

        class Outputs(BaseOutputs):
            value: int

        def run(self) -> Outputs:
            time.sleep(0.03)
            return self.Outputs(value=self.item + 1)

    # WHEN the node is run (timed with a monotonic clock, unaffected by wall-clock adjustments)
    node = TestNode(state=BaseState())
    start = time.perf_counter()
    outputs = node.run()
    run_time = time.perf_counter() - start

    # THEN the node should have run in parallel: a serial run needs at least 10 * 0.03s = 0.3s, so a bound
    # comfortably below that proves the iterations overlapped while tolerating thread/workflow overhead on slow hosts
    assert run_time < 0.2
    # AND the results are still ordered by item, not by completion
    assert outputs.value == list(range(1, 11))
@@ -0,0 +1,5 @@
1
+ from .node import RetryNode
2
+
3
# Names exported by `from ... import *` on this package.
__all__ = ["RetryNode"]
@@ -0,0 +1,106 @@
1
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Type
2
+
3
+ from vellum.workflows.errors.types import VellumErrorCode
4
+ from vellum.workflows.exceptions import NodeException
5
+ from vellum.workflows.inputs.base import BaseInputs
6
+ from vellum.workflows.nodes.bases import BaseNode
7
+ from vellum.workflows.nodes.bases.base import BaseNodeMeta
8
+ from vellum.workflows.state.base import BaseState
9
+ from vellum.workflows.types.generics import StateType
10
+
11
+ if TYPE_CHECKING:
12
+ from vellum.workflows import BaseWorkflow
13
+
14
+
15
class _RetryNodeMeta(BaseNodeMeta):
    """
    Metaclass for RetryNode that extends the inherited ``_localns`` namespace with the class's own
    ``SubworkflowInputs`` (presumably so string annotations such as
    ``Type["BaseWorkflow[SubworkflowInputs, BaseState]"]`` can resolve it — confirm against BaseNodeMeta).
    """

    @property
    def _localns(cls) -> Dict[str, Any]:
        namespace = dict(super()._localns)
        namespace["SubworkflowInputs"] = getattr(cls, "SubworkflowInputs")
        return namespace
22
+
23
+
24
class RetryNode(BaseNode[StateType], Generic[StateType], metaclass=_RetryNodeMeta):
    """
    Used to retry a Subworkflow a specified number of times.

    max_attempts: int - The maximum number of attempts to retry the Subworkflow
    retry_on_error_code: Optional[VellumErrorCode] = None - The error code to retry on; when None, any
        rejection is retried
    subworkflow: Type["BaseWorkflow[SubworkflowInputs, BaseState]"] - The Subworkflow to execute
    """

    max_attempts: int
    retry_on_error_code: Optional[VellumErrorCode] = None
    subworkflow: Type["BaseWorkflow[SubworkflowInputs, BaseState]"]

    class SubworkflowInputs(BaseInputs):
        attempt_number: int

    def run(self) -> BaseNode.Outputs:
        """
        Run the Subworkflow up to ``max_attempts`` times, returning the outputs of the first fulfilled attempt.

        Raises:
            NodeException: with INVALID_OUTPUTS if an attempt pauses or is rejected with a code other than
                ``retry_on_error_code``; otherwise, once attempts run out, with the last rejection's own code.
            Exception: if ``max_attempts`` is not greater than 0.
        """
        last_exception = Exception("max_attempts must be greater than 0")
        for index in range(self.max_attempts):
            attempt_number = index + 1
            subworkflow = self.subworkflow(
                parent_state=self.state,
            )
            terminal_event = subworkflow.run(
                inputs=self.SubworkflowInputs(attempt_number=attempt_number),
            )
            if terminal_event.name == "workflow.execution.fulfilled":
                node_outputs = self.Outputs()
                workflow_output_vars = vars(terminal_event.outputs)

                for output_name in workflow_output_vars:
                    setattr(node_outputs, output_name, workflow_output_vars[output_name])

                return node_outputs
            elif terminal_event.name == "workflow.execution.paused":
                last_exception = NodeException(
                    code=VellumErrorCode.INVALID_OUTPUTS,
                    message=f"Subworkflow unexpectedly paused on attempt {attempt_number}",
                )
                break
            elif self.retry_on_error_code and self.retry_on_error_code != terminal_event.error.code:
                # Rejected with an error code this node was not configured to retry: stop immediately.
                last_exception = NodeException(
                    code=VellumErrorCode.INVALID_OUTPUTS,
                    message=(
                        f"Unexpected rejection on attempt {attempt_number}: {terminal_event.error.code.value}.\n"
                        f"Message: {terminal_event.error.message}"
                    ),
                )
                break
            else:
                # Retryable rejection. Keep the subworkflow's own error code so that, if attempts run out,
                # the failure surfaces with that code rather than as a generic internal error.
                last_exception = NodeException(
                    terminal_event.error.message,
                    code=terminal_event.error.code,
                )

        raise last_exception

    @classmethod
    def wrap(
        cls, max_attempts: int, retry_on_error_code: Optional[VellumErrorCode] = None
    ) -> Callable[..., Type["RetryNode"]]:
        """Build a decorator that turns a node class into a RetryNode running that node as its Subworkflow."""
        # Class bodies cannot see same-named enclosing-function variables, hence the underscore aliases.
        _max_attempts = max_attempts
        _retry_on_error_code = retry_on_error_code

        def decorator(inner_cls: Type[BaseNode]) -> Type["RetryNode"]:
            # Investigate how to use dependency injection to avoid circular imports
            # https://app.shortcut.com/vellum/story/4116
            from vellum.workflows import BaseWorkflow

            class Subworkflow(BaseWorkflow[RetryNode.SubworkflowInputs, BaseState]):
                graph = inner_cls

                # mypy is wrong here, this works and is defined
                class Outputs(inner_cls.Outputs):  # type: ignore[name-defined]
                    pass

            class WrappedNode(RetryNode[StateType]):
                max_attempts = _max_attempts
                retry_on_error_code = _retry_on_error_code

                subworkflow = Subworkflow

                class Outputs(Subworkflow.Outputs):
                    pass

            return WrappedNode

        return decorator
@@ -0,0 +1,93 @@
1
+ import pytest
2
+
3
+ from vellum.workflows.errors.types import VellumErrorCode
4
+ from vellum.workflows.exceptions import NodeException
5
+ from vellum.workflows.inputs.base import BaseInputs
6
+ from vellum.workflows.nodes.bases import BaseNode
7
+ from vellum.workflows.nodes.core.retry_node.node import RetryNode
8
+ from vellum.workflows.outputs import BaseOutputs
9
+ from vellum.workflows.state.base import BaseState, StateMeta
10
+
11
+
12
def test_retry_node__retry_on_error_code__successfully_retried():
    # GIVEN a retry node that retries PROVIDER_ERROR rejections, wrapping a node that only succeeds on attempt 3
    @RetryNode.wrap(max_attempts=3, retry_on_error_code=VellumErrorCode.PROVIDER_ERROR)
    class TestNode(BaseNode):
        attempt_number = RetryNode.SubworkflowInputs.attempt_number

        class Outputs(BaseOutputs):
            execution_count: int

        def run(self) -> Outputs:
            if self.attempt_number >= 3:
                return self.Outputs(execution_count=self.attempt_number)
            raise NodeException(message="This will be retried", code=VellumErrorCode.PROVIDER_ERROR)

    # WHEN the node is run and the first two attempts throw PROVIDER_ERROR
    outputs = TestNode(state=BaseState()).run()

    # THEN those failures are retried and the third attempt's output is returned
    assert outputs.execution_count == 3
33
+
34
+
35
def test_retry_node__retry_on_error_code__missed():
    # GIVEN a retry node that only retries PROVIDER_ERROR rejections
    @RetryNode.wrap(max_attempts=3, retry_on_error_code=VellumErrorCode.PROVIDER_ERROR)
    class TestNode(BaseNode):
        attempt_number = RetryNode.SubworkflowInputs.attempt_number

        class Outputs(BaseOutputs):
            execution_count: int

        def run(self) -> Outputs:
            if self.attempt_number >= 3:
                return self.Outputs(execution_count=self.attempt_number)
            raise Exception("This will not be retried")

    # WHEN the wrapped node fails with a plain exception (which is not a PROVIDER_ERROR)
    node = TestNode(state=BaseState())
    with pytest.raises(NodeException) as exc_info:
        node.run()

    # THEN the first failure ends the run without a retry
    error = exc_info.value
    expected_message = "Unexpected rejection on attempt 1: INTERNAL_ERROR.\nMessage: This will not be retried"
    assert error.message == expected_message
    assert error.code == VellumErrorCode.INVALID_OUTPUTS
61
+
62
+
63
def test_retry_node__use_parent_inputs_and_state():
    # GIVEN the parent workflow's Inputs and State classes
    class Inputs(BaseInputs):
        foo: str

    class State(BaseState):
        bar: str

    # AND a retry node whose wrapped node reads the parent's inputs and state
    @RetryNode.wrap(max_attempts=3, retry_on_error_code=VellumErrorCode.PROVIDER_ERROR)
    class TestNode(BaseNode):
        foo = Inputs.foo
        bar = State.bar

        class Outputs(BaseOutputs):
            value: str

        def run(self) -> Outputs:
            return self.Outputs(value=f"{self.foo} {self.bar}")

    # WHEN it runs against a parent state that carries those inputs
    parent_state = State(bar="bar", meta=StateMeta(workflow_inputs=Inputs(foo="foo")))
    outputs = TestNode(state=parent_state).run()

    # THEN the wrapped node saw the parent's data
    assert outputs.value == "foo bar"
@@ -0,0 +1,5 @@
1
+ from .node import TemplatingNode
2
+
3
# Names exported by `from ... import *` on this package.
__all__ = ["TemplatingNode"]
@@ -0,0 +1,12 @@
1
+ import json
2
+ from typing import Union
3
+
4
+
5
def is_valid_json_string(value: Union[str, bytes]) -> bool:
    """
    Determine whether ``value`` is a string (or bytes) that parses as JSON.

    This function is registered as a Jinja filter, so templates may hand it arbitrary objects. Anything
    ``json.loads`` cannot parse, including non-string inputs such as numbers, dicts or ``None``, yields
    ``False`` rather than raising and aborting the render.
    """

    try:
        json.loads(value)
    except (ValueError, TypeError):
        # JSONDecodeError and UnicodeDecodeError subclass ValueError; TypeError signals a non-string input.
        return False
    return True
@@ -0,0 +1,2 @@
1
class JinjaTemplateError(Exception):
    """Error type for Jinja templating failures; imported by the templating node alongside its renderer."""
@@ -0,0 +1,123 @@
1
+ import datetime
2
+ import itertools
3
+ import json
4
+ import random
5
+ import re
6
+ from typing import Any, Callable, ClassVar, Dict, Generic, Mapping, Tuple, Type, TypeVar, Union, get_args
7
+
8
+ import dateutil.parser
9
+ import pydash
10
+ import pytz
11
+ import yaml
12
+
13
+ from vellum.workflows.errors import VellumErrorCode
14
+ from vellum.workflows.exceptions import NodeException
15
+ from vellum.workflows.nodes.bases import BaseNode
16
+ from vellum.workflows.nodes.bases.base import BaseNodeMeta
17
+ from vellum.workflows.nodes.core.templating_node.custom_filters import is_valid_json_string
18
+ from vellum.workflows.nodes.core.templating_node.exceptions import JinjaTemplateError
19
+ from vellum.workflows.nodes.core.templating_node.render import render_sandboxed_jinja_template
20
+ from vellum.workflows.types.core import EntityInputsInterface
21
+ from vellum.workflows.types.generics import StateType
22
+ from vellum.workflows.types.utils import get_original_base
23
+
24
# Default modules exposed to templates as globals (overridable via TemplatingNode.jinja_globals).
_DEFAULT_JINJA_GLOBALS: Dict[str, Any] = dict(
    datetime=datetime,
    dateutil=dateutil,
    itertools=itertools,
    json=json,
    pydash=pydash,
    pytz=pytz,
    random=random,
    re=re,
    yaml=yaml,
)

# Default extra filters, keyed by the name templates use (overridable via
# TemplatingNode.jinja_custom_filters).
_DEFAULT_JINJA_CUSTOM_FILTERS: Dict[str, Callable[[Union[str, bytes]], bool]] = dict(
    is_valid_json_string=is_valid_json_string,
)

# Type parameter for the rendered result; resolved per subclass by _TemplatingNodeMeta.
_OutputType = TypeVar("_OutputType")
41
+
42
+
43
# TODO: Consolidate all dynamic output metaclasses
# https://app.shortcut.com/vellum/story/5533
class _TemplatingNodeMeta(BaseNodeMeta):
    """Metaclass that pins the ``result`` output annotation of each TemplatingNode subclass."""

    def __new__(mcs, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any]) -> Any:
        node_class = super().__new__(mcs, name, bases, dct)

        if isinstance(node_class, _TemplatingNodeMeta):
            # Overwrite the generic placeholder annotation with the concrete output type.
            outputs_class = node_class.__dict__["Outputs"]
            outputs_class.__annotations__["result"] = node_class.get_output_type()
            return node_class

        raise ValueError("TemplatingNode must be created with the TemplatingNodeMeta metaclass")

    def get_output_type(cls) -> Type:
        """Return the second generic argument of the class, or ``str`` when it is unset."""
        type_args = get_args(get_original_base(cls))
        has_concrete_output = len(type_args) >= 2 and not isinstance(type_args[1], TypeVar)
        return type_args[1] if has_concrete_output else str
63
+
64
+
65
class TemplatingNode(BaseNode[StateType], Generic[StateType, _OutputType], metaclass=_TemplatingNodeMeta):
    """Used to render a Jinja template.

    Useful for lightweight data transformations and complex string templating.

    The second generic parameter selects the type of ``Outputs.result``. The rendered text
    is converted to ``str`` (the default), ``float``, ``int`` or ``bool``.
    """

    # The Jinja template to render.
    template: ClassVar[str]

    # The inputs to render the template with.
    inputs: ClassVar[EntityInputsInterface]

    # Globals and filters made available inside the template; subclasses may override them.
    jinja_globals: Dict[str, Any] = _DEFAULT_JINJA_GLOBALS
    jinja_custom_filters: Mapping[str, Callable[[Union[str, bytes]], bool]] = _DEFAULT_JINJA_CUSTOM_FILTERS

    class Outputs(BaseNode.Outputs):
        # We use our mypy plugin to override the _OutputType with the actual output type
        # for downstream references to this output.
        result: _OutputType  # type: ignore[valid-type]

    def _cast_rendered_template(self, rendered_template: str) -> Any:
        """Convert the rendered template text to this node's declared output type.

        Raises:
            ValueError: if the output type is unsupported, or the text cannot be parsed as
                the requested ``int``/``float``.
        """
        # Reuse the metaclass's resolution so the runtime cast always agrees with the
        # annotation the metaclass installed on ``Outputs.result``.
        output_type = type(self).get_output_type()

        if output_type is str:
            return rendered_template

        if output_type is float:
            return float(rendered_template)

        if output_type is int:
            return int(rendered_template)

        if output_type is bool:
            # bool("False") is True, so plain truthiness would turn a rendered False (or 0 /
            # None) into True. Treat the text forms of falsy scalars as False instead.
            return rendered_template.strip().lower() not in ("", "false", "0", "none")

        raise ValueError(f"Unsupported output type: {output_type}")

    def run(self) -> Outputs:
        rendered_template = self._render_template()
        result = self._cast_rendered_template(rendered_template)

        return self.Outputs(result=result)

    def _render_template(self) -> str:
        """Render ``template`` with ``inputs``, surfacing failures as INVALID_TEMPLATE errors."""
        try:
            return render_sandboxed_jinja_template(
                template=self.template,
                input_values=self.inputs,
                jinja_custom_filters={**self.jinja_custom_filters},
                jinja_globals=self.jinja_globals,
            )
        except JinjaTemplateError as e:
            # Chain the cause so the original rendering traceback is preserved.
            raise NodeException(message=str(e), code=VellumErrorCode.INVALID_TEMPLATE) from e
@@ -0,0 +1,55 @@
1
+ import json
2
+ from typing import Any, Callable, Dict, Optional, Union
3
+
4
+ from jinja2.sandbox import SandboxedEnvironment
5
+
6
+ from vellum.workflows.nodes.core.templating_node.exceptions import JinjaTemplateError
7
+ from vellum.workflows.state.encoder import DefaultStateEncoder
8
+
9
+
10
def finalize(obj: Any) -> str:
    """Convert a value produced by a template output expression into its rendered text.

    Dicts and lists are serialized as JSON (via ``DefaultStateEncoder``), so that
    ``{{ data }}`` yields machine-readable output instead of a Python repr with single
    quotes. Every other value falls back to ``str()``.
    """
    if isinstance(obj, (dict, list)):
        return json.dumps(obj, cls=DefaultStateEncoder)

    return str(obj)
15
+
16
+
17
def _describe_exception(error: BaseException) -> str:
    """Return the most useful human-readable detail for ``error``."""
    # Most exceptions carry their message as the first arg. Fall back for arg-less ones so
    # that building the error message can never itself raise IndexError.
    if error.args:
        return str(error.args[0])
    return str(error) or type(error).__name__


def render_sandboxed_jinja_template(
    *,
    template: str,
    input_values: Dict[str, Any],
    jinja_custom_filters: Optional[Dict[str, Callable[[Union[str, bytes]], bool]]] = None,
    jinja_globals: Optional[Dict[str, Any]] = None,
) -> str:
    """Render a Jinja template within a sandboxed environment.

    Args:
        template: Jinja source to compile and render.
        input_values: Variables passed to ``render``.
        jinja_custom_filters: Extra filters registered on the environment, if any.
        jinja_globals: Extra globals added to the compiled template, if any.

    Returns:
        The rendered text. Trailing newlines in the template are kept, and dicts and lists
        are emitted via ``finalize``.

    Raises:
        JinjaTemplateError: if compiling or rendering fails for any reason. The original
            exception is chained as ``__cause__``.
    """
    try:
        environment = SandboxedEnvironment(
            keep_trailing_newline=True,
            finalize=finalize,
        )

        if jinja_custom_filters:
            environment.filters.update(jinja_custom_filters)

        jinja_template = environment.from_string(template)

        if jinja_globals:
            jinja_template.globals.update(jinja_globals)

        rendered_template = jinja_template.render(input_values)
    except json.JSONDecodeError as e:
        # Per the messages below, this comes from template code calling json.loads().
        # Match by prefix: the C scanner says "Invalid control character at", while the
        # pure-Python fallback includes the offending character in the message.
        if e.msg.startswith("Invalid control character"):
            raise JinjaTemplateError(
                "Unable to render jinja template:\n"
                "Cannot run json.loads() on JSON containing control characters. "
                "Use json.loads(input, strict=False) instead.",
            ) from e

        raise JinjaTemplateError(
            f"Unable to render jinja template:\nCannot run json.loads() on invalid JSON\n{e.args[0]}"
        ) from e
    except Exception as e:
        raise JinjaTemplateError(f"Unable to render jinja template:\n{_describe_exception(e)}") from e

    return rendered_template
@@ -0,0 +1,21 @@
1
+ import json
2
+
3
+ from vellum.workflows.nodes.core.templating_node.node import TemplatingNode
4
+
5
+
6
def test_templating_node__dict_output():
    """A dict rendered directly by a TemplatingNode comes out as JSON text."""

    # GIVEN a templating node that renders its dict input unchanged
    class TemplateNode(TemplatingNode):
        template = "{{ data }}"
        inputs = {"data": {"key": "value"}}

    # WHEN the node is run
    outputs = TemplateNode().run()

    # THEN the rendered result parses back into the original dict
    parsed = json.loads(outputs.result)
    assert parsed == {"key": "value"}
@@ -0,0 +1,5 @@
1
+ from .node import TryNode
2
+
3
# Public API of this package.
__all__ = ["TryNode"]