vellum-ai 0.9.16rc2__py3-none-any.whl → 0.9.16rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. vellum/plugins/__init__.py +0 -0
  2. vellum/plugins/pydantic.py +74 -0
  3. vellum/plugins/utils.py +19 -0
  4. vellum/plugins/vellum_mypy.py +639 -3
  5. vellum/workflows/README.md +90 -0
  6. vellum/workflows/__init__.py +5 -0
  7. vellum/workflows/constants.py +43 -0
  8. vellum/workflows/descriptors/__init__.py +0 -0
  9. vellum/workflows/descriptors/base.py +339 -0
  10. vellum/workflows/descriptors/tests/test_utils.py +83 -0
  11. vellum/workflows/descriptors/utils.py +90 -0
  12. vellum/workflows/edges/__init__.py +5 -0
  13. vellum/workflows/edges/edge.py +23 -0
  14. vellum/workflows/emitters/__init__.py +5 -0
  15. vellum/workflows/emitters/base.py +14 -0
  16. vellum/workflows/environment/__init__.py +5 -0
  17. vellum/workflows/environment/environment.py +7 -0
  18. vellum/workflows/errors/__init__.py +6 -0
  19. vellum/workflows/errors/types.py +20 -0
  20. vellum/workflows/events/__init__.py +31 -0
  21. vellum/workflows/events/node.py +125 -0
  22. vellum/workflows/events/tests/__init__.py +0 -0
  23. vellum/workflows/events/tests/test_event.py +216 -0
  24. vellum/workflows/events/types.py +52 -0
  25. vellum/workflows/events/utils.py +5 -0
  26. vellum/workflows/events/workflow.py +139 -0
  27. vellum/workflows/exceptions.py +15 -0
  28. vellum/workflows/expressions/__init__.py +0 -0
  29. vellum/workflows/expressions/accessor.py +52 -0
  30. vellum/workflows/expressions/and_.py +32 -0
  31. vellum/workflows/expressions/begins_with.py +31 -0
  32. vellum/workflows/expressions/between.py +38 -0
  33. vellum/workflows/expressions/coalesce_expression.py +41 -0
  34. vellum/workflows/expressions/contains.py +30 -0
  35. vellum/workflows/expressions/does_not_begin_with.py +31 -0
  36. vellum/workflows/expressions/does_not_contain.py +30 -0
  37. vellum/workflows/expressions/does_not_end_with.py +31 -0
  38. vellum/workflows/expressions/does_not_equal.py +25 -0
  39. vellum/workflows/expressions/ends_with.py +31 -0
  40. vellum/workflows/expressions/equals.py +25 -0
  41. vellum/workflows/expressions/greater_than.py +33 -0
  42. vellum/workflows/expressions/greater_than_or_equal_to.py +33 -0
  43. vellum/workflows/expressions/in_.py +31 -0
  44. vellum/workflows/expressions/is_blank.py +24 -0
  45. vellum/workflows/expressions/is_not_blank.py +24 -0
  46. vellum/workflows/expressions/is_not_null.py +21 -0
  47. vellum/workflows/expressions/is_not_undefined.py +22 -0
  48. vellum/workflows/expressions/is_null.py +21 -0
  49. vellum/workflows/expressions/is_undefined.py +22 -0
  50. vellum/workflows/expressions/less_than.py +33 -0
  51. vellum/workflows/expressions/less_than_or_equal_to.py +33 -0
  52. vellum/workflows/expressions/not_between.py +38 -0
  53. vellum/workflows/expressions/not_in.py +31 -0
  54. vellum/workflows/expressions/or_.py +32 -0
  55. vellum/workflows/graph/__init__.py +3 -0
  56. vellum/workflows/graph/graph.py +131 -0
  57. vellum/workflows/graph/tests/__init__.py +0 -0
  58. vellum/workflows/graph/tests/test_graph.py +437 -0
  59. vellum/workflows/inputs/__init__.py +5 -0
  60. vellum/workflows/inputs/base.py +55 -0
  61. vellum/workflows/logging.py +14 -0
  62. vellum/workflows/nodes/__init__.py +46 -0
  63. vellum/workflows/nodes/bases/__init__.py +7 -0
  64. vellum/workflows/nodes/bases/base.py +332 -0
  65. vellum/workflows/nodes/bases/base_subworkflow_node/__init__.py +5 -0
  66. vellum/workflows/nodes/bases/base_subworkflow_node/node.py +10 -0
  67. vellum/workflows/nodes/bases/tests/__init__.py +0 -0
  68. vellum/workflows/nodes/bases/tests/test_base_node.py +125 -0
  69. vellum/workflows/nodes/core/__init__.py +16 -0
  70. vellum/workflows/nodes/core/error_node/__init__.py +5 -0
  71. vellum/workflows/nodes/core/error_node/node.py +26 -0
  72. vellum/workflows/nodes/core/inline_subworkflow_node/__init__.py +5 -0
  73. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +73 -0
  74. vellum/workflows/nodes/core/map_node/__init__.py +5 -0
  75. vellum/workflows/nodes/core/map_node/node.py +147 -0
  76. vellum/workflows/nodes/core/map_node/tests/__init__.py +0 -0
  77. vellum/workflows/nodes/core/map_node/tests/test_node.py +65 -0
  78. vellum/workflows/nodes/core/retry_node/__init__.py +5 -0
  79. vellum/workflows/nodes/core/retry_node/node.py +106 -0
  80. vellum/workflows/nodes/core/retry_node/tests/__init__.py +0 -0
  81. vellum/workflows/nodes/core/retry_node/tests/test_node.py +93 -0
  82. vellum/workflows/nodes/core/templating_node/__init__.py +5 -0
  83. vellum/workflows/nodes/core/templating_node/custom_filters.py +12 -0
  84. vellum/workflows/nodes/core/templating_node/exceptions.py +2 -0
  85. vellum/workflows/nodes/core/templating_node/node.py +123 -0
  86. vellum/workflows/nodes/core/templating_node/render.py +55 -0
  87. vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +21 -0
  88. vellum/workflows/nodes/core/try_node/__init__.py +5 -0
  89. vellum/workflows/nodes/core/try_node/node.py +110 -0
  90. vellum/workflows/nodes/core/try_node/tests/__init__.py +0 -0
  91. vellum/workflows/nodes/core/try_node/tests/test_node.py +82 -0
  92. vellum/workflows/nodes/displayable/__init__.py +31 -0
  93. vellum/workflows/nodes/displayable/api_node/__init__.py +5 -0
  94. vellum/workflows/nodes/displayable/api_node/node.py +44 -0
  95. vellum/workflows/nodes/displayable/bases/__init__.py +11 -0
  96. vellum/workflows/nodes/displayable/bases/api_node/__init__.py +5 -0
  97. vellum/workflows/nodes/displayable/bases/api_node/node.py +70 -0
  98. vellum/workflows/nodes/displayable/bases/base_prompt_node/__init__.py +5 -0
  99. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +60 -0
  100. vellum/workflows/nodes/displayable/bases/inline_prompt_node/__init__.py +5 -0
  101. vellum/workflows/nodes/displayable/bases/inline_prompt_node/constants.py +13 -0
  102. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +118 -0
  103. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +98 -0
  104. vellum/workflows/nodes/displayable/bases/search_node.py +90 -0
  105. vellum/workflows/nodes/displayable/code_execution_node/__init__.py +5 -0
  106. vellum/workflows/nodes/displayable/code_execution_node/node.py +197 -0
  107. vellum/workflows/nodes/displayable/code_execution_node/tests/__init__.py +0 -0
  108. vellum/workflows/nodes/displayable/code_execution_node/tests/fixtures/__init__.py +0 -0
  109. vellum/workflows/nodes/displayable/code_execution_node/tests/fixtures/main.py +3 -0
  110. vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +111 -0
  111. vellum/workflows/nodes/displayable/code_execution_node/utils.py +10 -0
  112. vellum/workflows/nodes/displayable/conditional_node/__init__.py +5 -0
  113. vellum/workflows/nodes/displayable/conditional_node/node.py +25 -0
  114. vellum/workflows/nodes/displayable/final_output_node/__init__.py +5 -0
  115. vellum/workflows/nodes/displayable/final_output_node/node.py +43 -0
  116. vellum/workflows/nodes/displayable/guardrail_node/__init__.py +5 -0
  117. vellum/workflows/nodes/displayable/guardrail_node/node.py +97 -0
  118. vellum/workflows/nodes/displayable/inline_prompt_node/__init__.py +5 -0
  119. vellum/workflows/nodes/displayable/inline_prompt_node/node.py +41 -0
  120. vellum/workflows/nodes/displayable/merge_node/__init__.py +5 -0
  121. vellum/workflows/nodes/displayable/merge_node/node.py +10 -0
  122. vellum/workflows/nodes/displayable/prompt_deployment_node/__init__.py +5 -0
  123. vellum/workflows/nodes/displayable/prompt_deployment_node/node.py +45 -0
  124. vellum/workflows/nodes/displayable/search_node/__init__.py +5 -0
  125. vellum/workflows/nodes/displayable/search_node/node.py +26 -0
  126. vellum/workflows/nodes/displayable/subworkflow_deployment_node/__init__.py +5 -0
  127. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +156 -0
  128. vellum/workflows/nodes/displayable/tests/__init__.py +0 -0
  129. vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py +148 -0
  130. vellum/workflows/nodes/displayable/tests/test_search_node_wth_text_output.py +134 -0
  131. vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +80 -0
  132. vellum/workflows/nodes/utils.py +27 -0
  133. vellum/workflows/outputs/__init__.py +6 -0
  134. vellum/workflows/outputs/base.py +196 -0
  135. vellum/workflows/ports/__init__.py +7 -0
  136. vellum/workflows/ports/node_ports.py +75 -0
  137. vellum/workflows/ports/port.py +75 -0
  138. vellum/workflows/ports/utils.py +40 -0
  139. vellum/workflows/references/__init__.py +17 -0
  140. vellum/workflows/references/environment_variable.py +20 -0
  141. vellum/workflows/references/execution_count.py +20 -0
  142. vellum/workflows/references/external_input.py +49 -0
  143. vellum/workflows/references/input.py +7 -0
  144. vellum/workflows/references/lazy.py +55 -0
  145. vellum/workflows/references/node.py +43 -0
  146. vellum/workflows/references/output.py +78 -0
  147. vellum/workflows/references/state_value.py +23 -0
  148. vellum/workflows/references/vellum_secret.py +15 -0
  149. vellum/workflows/references/workflow_input.py +41 -0
  150. vellum/workflows/resolvers/__init__.py +5 -0
  151. vellum/workflows/resolvers/base.py +15 -0
  152. vellum/workflows/runner/__init__.py +5 -0
  153. vellum/workflows/runner/runner.py +588 -0
  154. vellum/workflows/runner/types.py +18 -0
  155. vellum/workflows/state/__init__.py +5 -0
  156. vellum/workflows/state/base.py +327 -0
  157. vellum/workflows/state/context.py +18 -0
  158. vellum/workflows/state/encoder.py +57 -0
  159. vellum/workflows/state/store.py +28 -0
  160. vellum/workflows/state/tests/__init__.py +0 -0
  161. vellum/workflows/state/tests/test_state.py +113 -0
  162. vellum/workflows/types/__init__.py +0 -0
  163. vellum/workflows/types/core.py +91 -0
  164. vellum/workflows/types/generics.py +14 -0
  165. vellum/workflows/types/stack.py +39 -0
  166. vellum/workflows/types/tests/__init__.py +0 -0
  167. vellum/workflows/types/tests/test_utils.py +76 -0
  168. vellum/workflows/types/utils.py +164 -0
  169. vellum/workflows/utils/__init__.py +0 -0
  170. vellum/workflows/utils/names.py +13 -0
  171. vellum/workflows/utils/tests/__init__.py +0 -0
  172. vellum/workflows/utils/tests/test_names.py +15 -0
  173. vellum/workflows/utils/tests/test_vellum_variables.py +25 -0
  174. vellum/workflows/utils/vellum_variables.py +81 -0
  175. vellum/workflows/vellum_client.py +18 -0
  176. vellum/workflows/workflows/__init__.py +5 -0
  177. vellum/workflows/workflows/base.py +365 -0
  178. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/METADATA +2 -1
  179. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/RECORD +245 -7
  180. vellum_cli/__init__.py +72 -0
  181. vellum_cli/aliased_group.py +103 -0
  182. vellum_cli/config.py +96 -0
  183. vellum_cli/image_push.py +112 -0
  184. vellum_cli/logger.py +36 -0
  185. vellum_cli/pull.py +73 -0
  186. vellum_cli/push.py +121 -0
  187. vellum_cli/tests/test_config.py +100 -0
  188. vellum_cli/tests/test_pull.py +152 -0
  189. vellum_ee/workflows/__init__.py +0 -0
  190. vellum_ee/workflows/display/__init__.py +0 -0
  191. vellum_ee/workflows/display/base.py +73 -0
  192. vellum_ee/workflows/display/nodes/__init__.py +4 -0
  193. vellum_ee/workflows/display/nodes/base_node_display.py +116 -0
  194. vellum_ee/workflows/display/nodes/base_node_vellum_display.py +36 -0
  195. vellum_ee/workflows/display/nodes/get_node_display_class.py +25 -0
  196. vellum_ee/workflows/display/nodes/tests/__init__.py +0 -0
  197. vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +47 -0
  198. vellum_ee/workflows/display/nodes/types.py +18 -0
  199. vellum_ee/workflows/display/nodes/utils.py +33 -0
  200. vellum_ee/workflows/display/nodes/vellum/__init__.py +32 -0
  201. vellum_ee/workflows/display/nodes/vellum/api_node.py +205 -0
  202. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +71 -0
  203. vellum_ee/workflows/display/nodes/vellum/conditional_node.py +217 -0
  204. vellum_ee/workflows/display/nodes/vellum/final_output_node.py +61 -0
  205. vellum_ee/workflows/display/nodes/vellum/guardrail_node.py +49 -0
  206. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +170 -0
  207. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +99 -0
  208. vellum_ee/workflows/display/nodes/vellum/map_node.py +100 -0
  209. vellum_ee/workflows/display/nodes/vellum/merge_node.py +48 -0
  210. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +68 -0
  211. vellum_ee/workflows/display/nodes/vellum/search_node.py +193 -0
  212. vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +58 -0
  213. vellum_ee/workflows/display/nodes/vellum/templating_node.py +67 -0
  214. vellum_ee/workflows/display/nodes/vellum/tests/__init__.py +0 -0
  215. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +106 -0
  216. vellum_ee/workflows/display/nodes/vellum/try_node.py +38 -0
  217. vellum_ee/workflows/display/nodes/vellum/utils.py +76 -0
  218. vellum_ee/workflows/display/tests/__init__.py +0 -0
  219. vellum_ee/workflows/display/tests/workflow_serialization/__init__.py +0 -0
  220. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +426 -0
  221. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +607 -0
  222. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +1175 -0
  223. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +235 -0
  224. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +511 -0
  225. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +372 -0
  226. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +272 -0
  227. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +289 -0
  228. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +354 -0
  229. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +123 -0
  230. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +84 -0
  231. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +233 -0
  232. vellum_ee/workflows/display/types.py +46 -0
  233. vellum_ee/workflows/display/utils/__init__.py +0 -0
  234. vellum_ee/workflows/display/utils/tests/__init__.py +0 -0
  235. vellum_ee/workflows/display/utils/tests/test_uuids.py +16 -0
  236. vellum_ee/workflows/display/utils/uuids.py +24 -0
  237. vellum_ee/workflows/display/utils/vellum.py +121 -0
  238. vellum_ee/workflows/display/vellum.py +357 -0
  239. vellum_ee/workflows/display/workflows/__init__.py +5 -0
  240. vellum_ee/workflows/display/workflows/base_workflow_display.py +302 -0
  241. vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +32 -0
  242. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +386 -0
  243. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/LICENSE +0 -0
  244. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/WHEEL +0 -0
  245. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.9.16rc4.dist-info}/entry_points.txt +0 -0
@@ -1,10 +1,646 @@
1
- from typing import Type
1
+ import re
2
+ from typing import Callable, Dict, List, Optional, Set, Type
2
3
 
3
- from mypy.plugin import Plugin
4
+ from mypy.nodes import (
5
+ AssignmentStmt,
6
+ CallExpr,
7
+ Decorator,
8
+ MemberExpr,
9
+ NameExpr,
10
+ OverloadedFuncDef,
11
+ SymbolTableNode,
12
+ TypeAlias as MypyTypeAlias,
13
+ TypeInfo,
14
+ Var,
15
+ )
16
+ from mypy.options import Options
17
+ from mypy.plugin import AttributeContext, ClassDefContext, FunctionSigContext, MethodContext, Plugin
18
+ from mypy.types import AnyType, CallableType, FunctionLike, Instance, Type as MypyType, TypeAliasType, UnionType
19
+
20
# Callback that builds a mypy type from a fully-qualified class name plus its type
# arguments, e.g. ("vellum.workflows.descriptors.base.BaseDescriptor", [T]).
# In this plugin it is backed by `ctx.api.named_type(fullname, types)`.
TypeResolver = Callable[[str, List[MypyType]], MypyType]
21
+
22
# Attribute-name pattern shared by most entries: any name not starting with "_".
_PUBLIC_ATTRIBUTE_REGEX = r"^[^_].*$"

# (owner base class fullname, descriptor class fullname, attribute-name regex).
#
# An attribute whose name matches the regex, accessed on a subclass of the owner base
# class, is treated by the plugin as an instance of the descriptor class. Consumers
# scan this list in order and stop at the first matching entry, so order matters.
DESCRIPTOR_PATHS: list[tuple[str, str, str]] = [
    (
        "vellum.workflows.outputs.base.BaseOutputs",
        "vellum.workflows.references.output.OutputReference",
        _PUBLIC_ATTRIBUTE_REGEX,
    ),
    (
        "vellum.workflows.nodes.bases.base.BaseNode.ExternalInputs",
        "vellum.workflows.references.external_input.ExternalInputReference",
        _PUBLIC_ATTRIBUTE_REGEX,
    ),
    (
        "vellum.workflows.nodes.bases.base.BaseNode.Execution",
        "vellum.workflows.references.execution_count.ExecutionCountReference",
        r"^count$",
    ),
    (
        "vellum.workflows.state.base.BaseState",
        "vellum.workflows.references.state_value.StateValueReference",
        _PUBLIC_ATTRIBUTE_REGEX,
    ),
    (
        "vellum.workflows.inputs.base.BaseInputs",
        "vellum.workflows.references.workflow_input.WorkflowInputReference",
        _PUBLIC_ATTRIBUTE_REGEX,
    ),
    (
        "vellum.workflows.nodes.bases.base.BaseNode",
        "vellum.workflows.references.node.NodeReference",
        r"^[a-z].*$",
    ),
]
54
+
55
+
56
def _is_subclass(type_info: Optional[TypeInfo], fullname: str) -> bool:
    """
    Return True when `type_info` is the class named `fullname` or inherits from it.

    Walks the declared `bases` chain depth-first, left to right; a missing
    `type_info` is never a subclass of anything.
    """
    pending = [type_info] if type_info else []
    while pending:
        current = pending.pop()
        if current.fullname == fullname:
            return True
        # Push in reverse so the leftmost base is visited first.
        pending.extend(base.type for base in current.bases[::-1])
    return False
64
+
65
+
66
def _get_attribute_mypy_type(type_info: TypeInfo, attribute_name: str) -> Optional[MypyType]:
    """
    Look up the declared mypy type of `attribute_name` on `type_info`.

    A symbol defined directly on the class wins, even when its type is still None.
    Otherwise the bases are searched depth-first, left to right. The first truthy
    type found is returned, or None if nothing is found.
    """
    own_symbol = type_info.names.get(attribute_name)
    if own_symbol:
        return own_symbol.type

    inherited_types = (_get_attribute_mypy_type(base.type, attribute_name) for base in type_info.bases)
    return next((found for found in inherited_types if found), None)
4
78
 
5
79
 
6
80
  class VellumMypyPlugin(Plugin):
7
- pass
81
+ """
82
+ This plugin is responsible for properly supporting types for all of the magic we
83
+ do with Descriptors and this library in general.
84
+ """
85
+
86
def __init__(self, options: Options) -> None:
    """
    Initialise state shared between mypy's two phases (semantic analysis, then type checking).

    - `_nested_descriptor_expressions`: maps each descriptor-referencing `MemberExpr`
      argument to the `TypeResolver` used later to widen the matching parameter type.
    - `_calls_with_nested_descriptor_expressions`: the call expressions, found in
      `BaseNode` / `BaseOutputs` class bodies, that have such an argument.

    Both are filled while classes are analysed and read by the function-signature hook.
    This matters because plain classes such as dataclasses do not accept descriptors.
    """

    self._nested_descriptor_expressions: Dict[MemberExpr, TypeResolver] = dict()
    self._calls_with_nested_descriptor_expressions: Set[CallExpr] = set()

    super().__init__(options)
106
+
107
def get_class_attribute_hook(self, fullname: str) -> Optional[Callable[[AttributeContext], MypyType]]:
    """
    Called by mypy for class-level attribute access, e.g. `MyClass.my_attribute`.

    The same handler is returned for every `fullname`. `_class_attribute_hook` then decides
    whether the access should be retyped as one of the descriptors in `DESCRIPTOR_PATHS`.

    TODO: We still need to support all other descriptors besides Outputs.
    https://app.shortcut.com/vellum/story/4768
    NOTE(review): DESCRIPTOR_PATHS already lists more than Outputs; this TODO may be stale.
    """
    hook = self._class_attribute_hook
    return hook
117
+
118
def _class_attribute_hook(self, ctx: AttributeContext) -> MypyType:
    """
    Retype `SomeClass.attr` as `Descriptor[T]` when `SomeClass` subclasses one of the
    base classes in `DESCRIPTOR_PATHS` and `attr` matches that entry's regex.

    Methods (decorated or overloaded) are left alone. `T` is the declared attribute type
    with any descriptor wrapping stripped. In every other case mypy's default type is kept.
    """
    default = ctx.default_attr_type
    class_callable = ctx.type
    member = ctx.context
    is_class_member_access = (
        isinstance(class_callable, CallableType)
        and isinstance(class_callable.ret_type, Instance)
        and isinstance(member, MemberExpr)
    )
    if not is_class_member_access:
        return default

    owner = class_callable.ret_type.type
    for base_path, descriptor_path, attribute_regex in DESCRIPTOR_PATHS:
        if re.match(attribute_regex, member.name) is None or not _is_subclass(owner, base_path):
            continue

        existing = owner.names.get(member.name)
        if existing and isinstance(existing.node, (Decorator, OverloadedFuncDef)):
            continue

        descriptor = self.lookup_fully_qualified(descriptor_path)
        if descriptor and isinstance(descriptor.node, TypeInfo):
            return Instance(descriptor.node, [self._resolve_descriptor_type(default)])

    return default
144
+
145
def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
    """
    Called by mypy for each class definition with a base class, e.g.
    `class MyClass(BaseNode): ...`.

    A single handler is returned for every base. `_base_class_hook` applies the special
    handling for the project's base classes.
    """

    hook = self._base_class_hook
    return hook
153
+
154
def _base_class_hook(self, ctx: ClassDefContext) -> None:
    """
    Dispatch special handling for a newly defined class based on what it inherits from.

    First, nodes with a dynamically typed output (templating, code execution and final
    output nodes) get their `Outputs` attribute typed. Then at most one of the
    BaseNode / BaseWorkflow / BaseOutputs handlers runs, in that priority order.
    """
    class_info = ctx.cls.info

    dynamic_output_nodes = (
        ("vellum.workflows.nodes.core.templating_node.node.TemplatingNode", "result"),
        ("vellum.workflows.nodes.displayable.code_execution_node.node.CodeExecutionNode", "result"),
        ("vellum.workflows.nodes.displayable.final_output_node.node.FinalOutputNode", "value"),
    )
    for node_path, output_attribute in dynamic_output_nodes:
        if _is_subclass(class_info, node_path):
            self._dynamic_output_node_class_hook(ctx, output_attribute)
            break

    if _is_subclass(class_info, "vellum.workflows.nodes.bases.base.BaseNode"):
        self._base_node_class_hook(ctx)
    elif _is_subclass(class_info, "vellum.workflows.workflows.base.BaseWorkflow"):
        self._base_workflow_class_hook(ctx)
    elif _is_subclass(class_info, "vellum.workflows.outputs.base.BaseOutputs"):
        self._base_node_outputs_class_hook(ctx)
170
+
171
def _dynamic_output_node_class_hook(self, ctx: ClassDefContext, attribute_name: str) -> None:
    """
    Type the inherited `Outputs.<attribute_name>` of a dynamic-output node subclass using
    the second generic argument (`_OutputType`) of its first declared base.

    Supported bases are TemplatingNode, CodeExecutionNode and FinalOutputNode. When a
    TemplatingNode base is parameterised with `Any` (or left unparameterised), the output
    defaults to `str`. Subclasses that declare their own `Outputs` are left untouched.

    Args:
        ctx: class-definition context of the subclass being analysed.
        attribute_name: name of the output attribute on `Outputs` to retype.
    """
    templating_node_path = "vellum.workflows.nodes.core.templating_node.node.TemplatingNode"
    supported_base_paths = (
        templating_node_path,
        "vellum.workflows.nodes.displayable.code_execution_node.node.CodeExecutionNode",
        "vellum.workflows.nodes.displayable.final_output_node.node.FinalOutputNode",
    )

    node_info = ctx.cls.info
    node_bases = node_info.bases
    if not node_bases or not isinstance(node_bases[0], Instance):
        return

    base_args = node_bases[0].args
    base_node = node_bases[0].type
    # Previously only TemplatingNode bases passed this check. That made the
    # CodeExecutionNode / FinalOutputNode dispatches to this hook no-ops.
    if not any(_is_subclass(base_node, path) for path in supported_base_paths):
        return

    if len(base_args) != 2:
        return

    resolved_output_type = base_args[1]
    if isinstance(resolved_output_type, AnyType):
        if not _is_subclass(base_node, templating_node_path):
            # The `str` fallback was only ever reachable for TemplatingNode bases.
            # For the other node types, keep the inherited annotation rather than guess.
            return
        resolved_output_type = ctx.api.named_type("builtins.str")

    base_outputs = base_node.names.get("Outputs")
    if not base_outputs:
        return

    if node_info.names.get("Outputs"):
        # The subclass defines its own Outputs; respect it.
        return

    node_info.names["Outputs"] = base_outputs.copy()
    # NOTE(review): SymbolTableNode.copy() appears to keep the same underlying TypeInfo,
    # so the Var mutated below may be shared with the base node's Outputs (and therefore
    # with sibling subclasses). Confirm before relying on per-subclass output types.
    outputs_info = node_info.names["Outputs"].node
    if not isinstance(outputs_info, TypeInfo):
        return

    # `.get` rather than indexing: a missing attribute should be a no-op, not a mypy crash.
    output_symbol = outputs_info.names.get(attribute_name)
    if output_symbol is not None and isinstance(output_symbol.node, Var):
        output_symbol.node.type = resolved_output_type
208
+
209
def _base_node_class_hook(self, ctx: ClassDefContext) -> None:
    """
    Handle a `BaseNode` subclass definition by letting its class attributes accept
    descriptors as well as their declared types.
    """
    return self._redefine_class_attributes_with_descriptors(ctx)
214
+
215
def _base_node_outputs_class_hook(self, ctx: ClassDefContext) -> None:
    """
    Handle a `BaseOutputs` subclass definition by letting its class attributes accept
    descriptors as well as their declared types.
    """
    return self._redefine_class_attributes_with_descriptors(ctx)
220
+
221
def _redefine_class_attributes_with_descriptors(self, ctx: ClassDefContext) -> None:
    """
    Widen every typed class attribute so it accepts either its declared type or a
    descriptor of that type.

    Also record descriptor references nested inside calls assigned in the class body,
    e.g. `x = SomeDataclass(field=OtherNode.Outputs.y)`, so their call signatures can be
    widened later.
    """

    def resolve_named_type(fullname: str, types: List[MypyType]) -> MypyType:
        return ctx.api.named_type(fullname, types)

    typed_vars = (
        symbol.node for symbol in ctx.cls.info.names.values() if isinstance(symbol.node, Var) and symbol.node.type
    )
    for var in typed_vars:
        var.type = self._get_resolvable_type(resolve_named_type, var.type)

    # Supports descriptors assigned to nested classes
    for statement in ctx.cls.defs.body:
        if isinstance(statement, AssignmentStmt) and isinstance(statement.rvalue, CallExpr):
            self._collect_descriptor_expressions(resolve_named_type, statement.rvalue)
250
+
251
def _collect_descriptor_expressions(self, type_resolver: TypeResolver, call_expr: CallExpr) -> None:
    """
    Scan the arguments of `call_expr` for member expressions that reference a descriptor.

    An argument counts as a descriptor reference when it matches an entry in
    `DESCRIPTOR_PATHS`, or when it is an inherited `SomeNode.Outputs.<name>`. Each match
    adds `call_expr` to `_calls_with_nested_descriptor_expressions` and maps the argument
    to `type_resolver`. Arguments that are themselves calls are scanned recursively and
    registered against that nested call.
    """
    for arg in call_expr.args:
        if isinstance(arg, CallExpr):
            self._collect_descriptor_expressions(type_resolver, arg)
            continue
        if not isinstance(arg, MemberExpr):
            continue

        owner = arg.expr
        references_descriptor = any(
            re.match(attribute_regex, arg.name)
            and isinstance(owner, (NameExpr, MemberExpr))
            and isinstance(owner.node, TypeInfo)
            and _is_subclass(owner.node, base_path)
            for base_path, _, attribute_regex in DESCRIPTOR_PATHS
        )

        if not references_descriptor:
            # Need to check node outputs that are inherited
            references_descriptor = bool(
                re.match(r"^[^_].*$", arg.name)
                and isinstance(owner, MemberExpr)
                and owner.name == "Outputs"
                and isinstance(owner.expr, NameExpr)
                and isinstance(owner.expr.node, TypeInfo)
                and _is_subclass(owner.expr.node, "vellum.workflows.nodes.bases.base.BaseNode")
            )

        if references_descriptor:
            self._calls_with_nested_descriptor_expressions.add(call_expr)
            self._nested_descriptor_expressions[arg] = type_resolver
291
+
292
def _base_workflow_class_hook(self, ctx: ClassDefContext) -> None:
    """
    Hook point for `BaseWorkflow` subclass definitions; currently a no-op.
    """

    return None
298
+
299
def get_function_signature_hook(self, fullname: str) -> Optional[Callable[[FunctionSigContext], FunctionLike]]:
    """
    Called by mypy while checking a call's signature, e.g. `f(a, b)`. Class instantiation
    also counts as a call.

    This lets descriptors be passed to constructors of classes we don't control, such as
    dataclasses, when the call appears inside a node. The returned handler only acts on
    calls previously recorded by `_collect_descriptor_expressions`.
    """

    hook = self._function_signature_hook
    return hook
309
+
310
def _function_signature_hook(self, ctx: FunctionSigContext) -> FunctionLike:
    """
    Widen parameter types of calls recorded by `_collect_descriptor_expressions`, so that
    arguments referencing a descriptor type-check against `T | BaseDescriptor[T]`.

    Each formal parameter is matched to the argument actually passed for it:
    - keyword arguments are matched by name;
    - otherwise the leading plain positional arguments are matched by index.

    Returns the default signature when the call is not recorded, or when none of its
    matched arguments is a recorded descriptor expression.
    """
    call = ctx.context
    if not isinstance(call, CallExpr) or call not in self._calls_with_nested_descriptor_expressions:
        return ctx.default_signature

    signature = ctx.default_signature
    formal_types = signature.arg_types
    formal_names = signature.arg_names
    formal_kinds = signature.arg_kinds

    # Actuals passed by keyword (`f(x=...)`), keyed by the keyword name.
    keyword_actuals = {name: actual for actual, name in zip(call.args, call.arg_names) if name is not None}
    # Plain positional actuals, up to the first keyword / `*args` / `**kwargs` actual.
    # After a star-expansion the positions of later arguments are not statically known.
    positional_actuals = []
    for actual, actual_kind in zip(call.args, call.arg_kinds):
        if not actual_kind.is_positional():
            break
        positional_actuals.append(actual)

    new_arg_types = []
    should_copy_new_signature = False
    for index, formal_type in enumerate(formal_types):
        if index >= len(formal_kinds) or index >= len(formal_names):
            new_arg_types.append(formal_type)
            continue

        formal_kind = formal_kinds[index]
        formal_name = formal_names[index]
        # Previously positional-or-keyword parameters were always matched by index, so
        # `Foo(b=Node.Outputs.x, a=1)` widened `a` instead of `b`.
        if formal_name is not None and formal_name in keyword_actuals and not formal_kind.is_star():
            actual = keyword_actuals[formal_name]
        elif formal_kind.is_positional(star=True) and index < len(positional_actuals):
            actual = positional_actuals[index]
        else:
            actual = None

        if isinstance(actual, MemberExpr) and actual in self._nested_descriptor_expressions:
            should_copy_new_signature = True
            new_arg_types.append(
                self._get_resolvable_type(
                    self._nested_descriptor_expressions[actual],
                    formal_type,
                )
            )
        else:
            new_arg_types.append(formal_type)

    if not should_copy_new_signature:
        return signature

    return signature.copy_modified(
        ret_type=signature.ret_type,
        arg_types=new_arg_types,
    )
361
+
362
def _get_resolvable_type(
    self,
    get_named_type: Callable[[str, List[MypyType]], MypyType],
    type_: MypyType,
    _expanding_aliases: Optional[Set[str]] = None,
) -> MypyType:
    """
    Return a widened version of `type_` that also accepts descriptors.

    - The `Json` alias, and any alias already being expanded (i.e. a self-referencing
      alias), becomes `alias | BaseDescriptor[alias]` instead of being expanded again.
    - Other type aliases are expanded and their target is widened.
    - Unions are widened member by member.
    - `dict[K, V]` keeps `K` and widens `V`.
    - Anything else becomes `T | BaseDescriptor[T]`.

    Args:
        get_named_type: resolver building an instance type from a fully-qualified name
            and type arguments.
        type_: the declared type to widen.
        _expanding_aliases: internal; fullnames of aliases currently being expanded,
            used to stop infinite recursion on self-referencing aliases.
    """
    expanding_aliases = _expanding_aliases or set()

    def with_descriptor(member: MypyType) -> MypyType:
        return UnionType(
            [
                member,
                get_named_type("vellum.workflows.descriptors.base.BaseDescriptor", [member]),
            ]
        )

    if isinstance(type_, TypeAliasType) and type_.alias:
        alias_fullname = type_.alias.fullname
        if alias_fullname == "vellum.workflows.types.core.Json" or alias_fullname in expanding_aliases:
            # Expanding a recursive alias again would recurse forever (previously only
            # `Json` was guarded). Accept a descriptor of the alias itself instead.
            return with_descriptor(type_)

        # Type aliases expand to an actual type (e.g. `Foo = str`), so keep drilling down.
        return self._get_resolvable_type(
            get_named_type,
            type_.alias.target,
            expanding_aliases | {alias_fullname},
        )

    if isinstance(type_, UnionType):
        # Accept a descriptor pointing to any of the individual union members.
        return UnionType(
            [self._get_resolvable_type(get_named_type, item, expanding_aliases) for item in type_.items]
        )

    if isinstance(type_, Instance) and type_.type.fullname == "builtins.dict":
        # Accept descriptors for the values the dict maps to; keys keep their type.
        key_type = type_.args[0]
        value_type = type_.args[1]
        return get_named_type(
            type_.type.fullname,
            [
                key_type,
                self._get_resolvable_type(get_named_type, value_type, expanding_aliases),
            ],
        )

    # Default: accept the type itself or a descriptor pointing to it.
    return with_descriptor(type_)
417
+
418
def get_attribute_hook(self, fullname: str) -> Optional[Callable[[AttributeContext], MypyType]]:
    """
    Called by mypy for instance attribute access. The same handler, `_attribute_hook`,
    is returned for every `fullname`.
    """
    hook = self._attribute_hook
    return hook
420
+
421
def _attribute_hook(self, ctx: AttributeContext) -> MypyType:
    """
    For `self.<attr>` accesses, strip descriptor members from the attribute's union type,
    so that inside instance code the attribute reads as its resolved value type.
    Every other attribute access keeps mypy's default type.
    """
    member = ctx.context
    accessed_on_self = (
        isinstance(member, MemberExpr) and isinstance(member.expr, NameExpr) and member.expr.name == "self"
    )
    if not accessed_on_self:
        return ctx.default_attr_type

    # TODO: ensure that `self` is a BaseNode
    # https://app.shortcut.com/vellum/story/5531

    return self._resolve_descriptor_type(ctx.default_attr_type)
431
+
432
def _resolve_descriptor_type(self, default_type: MypyType) -> MypyType:
    """
    Strip descriptor wrappers out of a union type.

    Each union member that is an Instance of a `BaseDescriptor` subclass with at least
    one type argument is replaced by its first type argument; all other members are kept
    as-is. Duplicates are then dropped, keeping first-seen order, so the resulting union
    (and any mypy message that prints it) has a stable member order.

    Returns:
        `default_type` unchanged if it is not a union (or the union is empty), the single
        remaining member if only one survives deduplication, otherwise a new UnionType.
    """
    if not isinstance(default_type, UnionType):
        return default_type

    non_descriptor_items = [
        (
            item.args[0]
            if isinstance(item, Instance)
            and _is_subclass(item.type, "vellum.workflows.descriptors.base.BaseDescriptor")
            and len(item.args) > 0
            else item
        )
        for item in default_type.items
    ]

    # `dict.fromkeys` dedupes exactly like `set()` but preserves insertion order. Set
    # iteration order follows element hashes, which are not guaranteed stable between
    # mypy runs, and that would make the order of the union members nondeterministic.
    new_items = list(dict.fromkeys(non_descriptor_items))

    if not new_items:
        return default_type

    if len(new_items) == 1:
        return new_items[0]

    return UnionType(items=new_items)
458
+
459
def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], MypyType]]:
    """
    Choose a return-type hook for a method call based on the method's name.

    Mypy consults this whenever a method call's return type is checked. Calls to any
    method whose fullname ends in `.run` or `.stream` are routed to hooks that re-type the
    returned workflow events with the caller's own `Outputs` class. Mypy cannot infer this
    by default because nested classes are not inherited generically, among other Python
    typing limitations. Every other method gets no hook.
    """
    hooks_by_suffix = (
        (".run", self._run_method_hook),
        (".stream", self._stream_method_hook),
    )
    for suffix, hook in hooks_by_suffix:
        if fullname.endswith(suffix):
            return hook
    return None
473
+
474
def _run_method_hook(self, ctx: MethodContext) -> MypyType:
    """
    Re-type the result of `Workflow.run()` using the workflow's own `Outputs` class.

    The declared return type is expected to be a type alias whose target is a union of
    events. If that union contains a `WorkflowExecutionFulfilledEvent` (at any position),
    that member is swapped for `WorkflowExecutionFulfilledEvent[<Workflow>.Outputs]`. The
    `Outputs` class is the one returned by `_get_outputs_node`. All other union members are
    kept in their original order, and the alias keeps its name and source position.

    Returns:
        The rebuilt alias type, or `ctx.default_return_type` unchanged whenever the
        expected shape or the workflow's `Outputs` class cannot be found.
    """
    default_return_type = ctx.default_return_type
    if not isinstance(default_return_type, TypeAliasType):
        return default_return_type

    alias = default_return_type.alias
    if not alias:
        return default_return_type

    alias_target = alias.target
    if not isinstance(alias_target, UnionType) or not alias_target.items:
        return default_return_type

    # Scan for the fulfilled event instead of assuming it is the first member, so that
    # reordering the union alias does not silently disable this hook. This matches how
    # `stream()` results are handled.
    fulfilled_event_index = -1
    fulfilled_event = None
    for index, item in enumerate(alias_target.items):
        if (
            isinstance(item, Instance)
            and item.type.fullname == "vellum.workflows.events.workflow.WorkflowExecutionFulfilledEvent"
        ):
            fulfilled_event_index = index
            fulfilled_event = item
            break

    if fulfilled_event is None:
        return default_return_type

    outputs_node = self._get_outputs_node(ctx)
    if not outputs_node:
        return default_return_type

    new_fulfilled_event = fulfilled_event.copy_modified(args=(Instance(outputs_node, []),))
    new_items = [
        new_fulfilled_event if index == fulfilled_event_index else item
        for index, item in enumerate(alias_target.items)
    ]
    return TypeAliasType(
        alias=MypyTypeAlias(
            target=UnionType(items=new_items),
            fullname=alias.fullname,
            line=alias.line,
            column=alias.column,
        ),
        args=default_return_type.args,
        line=default_return_type.line,
        column=default_return_type.column,
    )
516
+
517
def _stream_method_hook(self, ctx: MethodContext) -> MypyType:
    """
    Re-type the events yielded by `Workflow.stream()` using the workflow's own `Outputs` class.

    The declared return type is expected to be an alias of `Iterator[<event union alias>]`.
    When the inner union contains a `WorkflowExecutionFulfilledEvent`, that member is replaced
    by `WorkflowExecutionFulfilledEvent[<Workflow>.Outputs]`, using the class returned by
    `_get_outputs_node`. Both aliases are then rebuilt around the new union, keeping their
    names and source positions. Remaining iterator type arguments are carried over unchanged.

    Returns:
        The rebuilt alias type, or `ctx.default_return_type` unchanged whenever any part of
        that shape, or the workflow's `Outputs` class, cannot be found.
    """
    stream_type = ctx.default_return_type
    if not isinstance(stream_type, TypeAliasType):
        return stream_type

    stream_alias = stream_type.alias
    if not stream_alias:
        return stream_type

    iterator_type = stream_alias.target
    if not (
        isinstance(iterator_type, Instance)
        and _is_subclass(iterator_type.type, "typing.Iterator")
        and iterator_type.args
    ):
        return stream_type

    event_union_alias = iterator_type.args[0]
    if not (isinstance(event_union_alias, TypeAliasType) and event_union_alias.alias):
        return stream_type

    event_union = event_union_alias.alias.target
    if not (isinstance(event_union, UnionType) and event_union.items):
        return stream_type

    # The scan runs to completion, so if several members match, the last one wins.
    found_index = -1
    found_event = None
    for position, candidate in enumerate(event_union.items):
        if (
            isinstance(candidate, Instance)
            and candidate.type.fullname == "vellum.workflows.events.workflow.WorkflowExecutionFulfilledEvent"
        ):
            found_index, found_event = position, candidate

    if found_index == -1 or not found_event:
        return stream_type

    outputs_info = self._get_outputs_node(ctx)
    if not outputs_info:
        return stream_type

    typed_fulfilled_event = found_event.copy_modified(args=(Instance(outputs_info, []),))
    rebuilt_union_items = [
        typed_fulfilled_event if position == found_index else candidate
        for position, candidate in enumerate(event_union.items)
    ]
    rebuilt_union_alias = TypeAliasType(
        alias=MypyTypeAlias(
            target=UnionType(items=rebuilt_union_items),
            fullname=event_union_alias.alias.fullname,
            line=event_union_alias.alias.line,
            column=event_union_alias.alias.column,
        ),
        args=event_union_alias.args,
        line=event_union_alias.line,
        column=event_union_alias.column,
    )
    rebuilt_iterator = iterator_type.copy_modified(
        args=[rebuilt_union_alias, *iterator_type.args[1:]],
    )
    return TypeAliasType(
        alias=MypyTypeAlias(
            target=rebuilt_iterator,
            fullname=stream_alias.fullname,
            line=stream_alias.line,
            column=stream_alias.column,
        ),
        args=stream_type.args,
        line=stream_type.line,
        column=stream_type.column,
    )
597
+
598
def _get_outputs_node(self, ctx: MethodContext) -> Optional[TypeInfo]:
    """
    Find the `Outputs` class of the workflow a method is being called on.

    Only calls of the form `<expr>.method(...)` are handled. `<expr>` must have an Instance
    type whose class is a `BaseWorkflow` subclass. `Outputs` is looked up in that class's
    own `names` table and must be a `BaseOutputs` subclass. The symbol is then passed
    through `_resolve_descriptors_in_outputs` before its TypeInfo is returned.

    Returns:
        The resolved `Outputs` TypeInfo, or None when any of these conditions fails.
    """
    call = ctx.context
    if not (isinstance(call, CallExpr) and isinstance(call.callee, MemberExpr)):
        return None

    receiver_type = ctx.api.get_expression_type(call.callee.expr)
    if not (
        isinstance(receiver_type, Instance)
        and _is_subclass(receiver_type.type, "vellum.workflows.workflows.base.BaseWorkflow")
    ):
        return None

    outputs_symbol = receiver_type.type.names.get("Outputs")
    if not (
        outputs_symbol
        and isinstance(outputs_symbol.node, TypeInfo)
        and _is_subclass(outputs_symbol.node, "vellum.workflows.outputs.base.BaseOutputs")
    ):
        return None

    resolved = self._resolve_descriptors_in_outputs(outputs_symbol).node
    return resolved if isinstance(resolved, TypeInfo) else None
627
+
628
def _resolve_descriptors_in_outputs(self, type_info: SymbolTableNode) -> SymbolTableNode:
    """
    Replace `BaseDescriptor[T]` attribute types on an Outputs class with `T`.

    Every `Var` in the class's symbol table whose type is an Instance of a `BaseDescriptor`
    subclass with at least one type argument has its type set to that first argument.

    Returns:
        The result of `type_info.copy()`. If the copy's node is not a TypeInfo, it is
        returned without any changes.

    NOTE(review): `SymbolTableNode.copy()` appears to be shallow (the copy shares `.node`),
    so the `Var` objects updated below likely belong to the original Outputs class as well.
    That would leave a lasting change in mypy's symbol table. Confirm this is intended.
    """
    copied = type_info.copy()
    outputs_class = copied.node
    if not isinstance(outputs_class, TypeInfo):
        return copied

    for symbol in outputs_class.names.values():
        var = symbol.node
        if not isinstance(var, Var):
            continue
        declared = var.type
        if (
            isinstance(declared, Instance)
            and _is_subclass(declared.type, "vellum.workflows.descriptors.base.BaseDescriptor")
            and declared.args
        ):
            var.type = declared.args[0]

    return copied
8
644
 
9
645
 
10
646
  def plugin(version: str) -> Type[VellumMypyPlugin]: